mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Hipify revamp [REDUX] (#48715)
Summary: [Refiled version of earlier PR https://github.com/pytorch/pytorch/issues/45451] This PR revamps the hipify module in PyTorch to overcome a long list of shortcomings in the original implementation. However, these improvements are applied only when using hipify to build PyTorch extensions, not for PyTorch or Caffe2 itself. Correspondingly, changes are made to cpp_extension.py to match these improvements. The list of improvements to hipify is as follows: 1. Hipify files in the same directory as the original file, unless there's a "cuda" subdirectory in the original file path, in which case the hipified file will be in the corresponding file path with a "hip" subdirectory instead of "cuda". 2. Never hipify the file in-place if changes are introduced due to hipification, i.e., always ensure the hipified file either resides in a different folder or has a different filename compared to the original file. 3. Prevent re-hipification of already hipified files. This avoids creation of unnecessary "hip/hip" etc. subdirectories and additional files which have no actual use. 4. Do not write out hipified versions of files if they are identical to the original file. This results in a cleaner output directory, with a minimal number of hipified files created. 5. Update header rewrite logic so that it accounts for the previous improvement. 6. Update header rewrite logic so it respects the rules for finding header files depending on whether "" or <> is used. 7. Return a dictionary of mappings of original file paths to hipified file paths from the hipify function. 8. Introduce a version for the hipify module to allow extensions to contain back-compatible code that targets a specific point in PyTorch where the hipify functionality changed. 9. Update cuda_to_hip_mappings.py to account for the ROCm component subdirectories inside /opt/rocm/include. This also results in cleanup of the Caffe2_HIP_INCLUDE path to remove unnecessary additions to the include path.
The list of changes to cpp_extension.py is as follows: 1. Call hipify when building a CUDAExtension for ROCm. 2. Prune the list of source files to CUDAExtension to include only the hipified versions of any source files in the list (if both original and hipified versions of the source file are in the list). 3. Add subdirectories of /opt/rocm/include to the include path for extensions, so that ROCm headers for subcomponent libraries are found automatically. cc jeffdaily sunway513 ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/48715 Reviewed By: bdhirsh Differential Revision: D25272824 Pulled By: ezyang fbshipit-source-id: 8bba68b27e41ca742781e1c4d7b07c6f985f040e
This commit is contained in:
parent
780f2b9a9b
commit
5f62308739
|
|
@ -1209,7 +1209,7 @@ if(USE_ROCM)
|
|||
endforeach()
|
||||
|
||||
set(Caffe2_HIP_INCLUDE
|
||||
${thrust_INCLUDE_DIRS} ${hipcub_INCLUDE_DIRS} ${rocprim_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${roctracer_INCLUDE_DIRS} ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} $<INSTALL_INTERFACE:include> ${Caffe2_HIP_INCLUDE})
|
||||
$<INSTALL_INTERFACE:include> ${Caffe2_HIP_INCLUDE})
|
||||
# This is needed for library added by hip_add_library (same for hip_add_executable)
|
||||
hip_include_directories(${Caffe2_HIP_INCLUDE})
|
||||
|
||||
|
|
|
|||
|
|
@ -205,9 +205,4 @@ if(HIP_FOUND)
|
|||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
|
||||
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
|
||||
|
||||
# Necessary includes for building PyTorch since we include HIP headers that depend on hcc/hsa headers.
|
||||
set(hcc_INCLUDE_DIRS ${HCC_PATH}/include)
|
||||
set(hsa_INCLUDE_DIRS ${HSA_PATH}/include)
|
||||
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ ext_modules = [
|
|||
extra_compile_args=CXX_FLAGS),
|
||||
]
|
||||
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.cuda', [
|
||||
'cuda_extension.cpp',
|
||||
|
|
@ -39,25 +39,9 @@ if torch.cuda.is_available() and CUDA_HOME is not None:
|
|||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
elif torch.cuda.is_available() and ROCM_HOME is not None:
|
||||
from torch.utils.hipify import hipify_python
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
hipify_python.hipify(
|
||||
project_directory=this_dir,
|
||||
output_directory=this_dir,
|
||||
includes="./*",
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,)
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.cuda', [
|
||||
'cuda_extension.cpp',
|
||||
'hip/hip_extension_kernel.hip',
|
||||
'hip/hip_extension_kernel2.hip',
|
||||
])
|
||||
ext_modules.append(extension)
|
||||
|
||||
if not IS_WINDOWS: # MSVC has bug compiling this example
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'torch_library.cu'
|
||||
|
|
@ -65,21 +49,6 @@ if not IS_WINDOWS: # MSVC has bug compiling this example
|
|||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
elif torch.cuda.is_available() and ROCM_HOME is not None:
|
||||
from torch.utils.hipify import hipify_python
|
||||
hipify_python.hipify(
|
||||
project_directory=this_dir,
|
||||
output_directory=this_dir,
|
||||
includes="./*",
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,)
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'hip/torch_library.hip'
|
||||
],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
|
||||
setup(
|
||||
name='torch_test_cpp_extension',
|
||||
|
|
|
|||
|
|
@ -850,6 +850,27 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
|||
kwargs['libraries'] = libraries
|
||||
|
||||
include_dirs = kwargs.get('include_dirs', [])
|
||||
|
||||
if IS_HIP_EXTENSION:
|
||||
build_dir = os.getcwd()
|
||||
if not include_dirs:
|
||||
include_dirs = ['*']
|
||||
hipify_result = hipify_python.hipify(
|
||||
project_directory=build_dir,
|
||||
output_directory=build_dir,
|
||||
includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs],
|
||||
extra_files=[os.path.abspath(s) for s in sources],
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
)
|
||||
|
||||
hipified_sources = set()
|
||||
for source in sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
|
||||
|
||||
sources = list(hipified_sources)
|
||||
|
||||
include_dirs += include_paths(cuda=True)
|
||||
kwargs['include_dirs'] = include_dirs
|
||||
|
||||
|
|
@ -1698,7 +1719,7 @@ def _write_ninja_file_to_build_library(path,
|
|||
cuda_flags += _get_rocm_arch_flags(cuda_flags)
|
||||
sources = [s if not _is_cuda_file(s) else
|
||||
os.path.abspath(os.path.join(
|
||||
path, get_hip_file_path(os.path.relpath(s, path))))
|
||||
path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
|
||||
for s in sources]
|
||||
elif with_cuda:
|
||||
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .version import __version__
|
||||
|
|
@ -552,26 +552,26 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
|
|||
("vector_types.h", ("hip/hip_vector_types.h", CONV_INCLUDE, API_RUNTIME)),
|
||||
("cublas.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
|
||||
("cublas_v2.h", ("rocblas.h", CONV_INCLUDE_CUDA_MAIN_H, API_BLAS)),
|
||||
("curand.h", ("hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
|
||||
("curand_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete2.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_globals.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_lognormal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mrg32k3a.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_host.h", ("hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_kernel.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand.h", ("hiprand/hiprand.h", CONV_INCLUDE_CUDA_MAIN_H, API_RAND)),
|
||||
("curand_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_discrete2.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_globals.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_lognormal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mrg32k3a.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_host.h", ("hiprand/hiprand_mtgp32_host.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_mtgp32_kernel.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
(
|
||||
"curand_mtgp32dc_p_11213.h",
|
||||
("rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND),
|
||||
("rocrand/rocrand_mtgp32_11213.h", CONV_INCLUDE, API_RAND),
|
||||
),
|
||||
("curand_normal.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal_static.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_philox4x32_x.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_poisson.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_precalc.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_uniform.h", ("hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_normal_static.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_philox4x32_x.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_poisson.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
|
||||
("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)),
|
||||
("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
|
||||
("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)),
|
||||
|
|
@ -586,7 +586,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
|
|||
("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
|
||||
("nvToolsExt.h", ("roctx.h", CONV_INCLUDE, API_ROCTX)),
|
||||
("nvToolsExt.h", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,8 +34,12 @@ from . import constants
|
|||
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
|
||||
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
|
||||
|
||||
from typing import Dict, List
|
||||
from collections.abc import Mapping
|
||||
from typing import Dict, List, Iterator, Optional
|
||||
from collections.abc import Mapping, Iterable
|
||||
HipifyResult = Dict[str, Optional[str]]
|
||||
HipifyFinalResult = Dict[str, HipifyResult]
|
||||
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
|
||||
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
|
||||
|
||||
# Hardcode the PyTorch template map
|
||||
"""This dictionary provides the mapping from PyTorch kernel template types
|
||||
|
|
@ -109,14 +113,20 @@ class GeneratedFileCleaner:
|
|||
for d in self.dirs_to_clean[::-1]:
|
||||
os.rmdir(d)
|
||||
|
||||
def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), out_of_place_only=False, is_pytorch_extension=False):
|
||||
def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Report whether ``filename`` ends with at least one entry of ``extensions``.

    Args:
        filename: File name or path whose suffix is inspected.
        extensions: Candidate suffixes, each passed to ``str.endswith``.

    Returns:
        True as soon as one suffix matches; False when none match or when
        ``extensions`` is empty.
    """
    for suffix in extensions:
        if filename.endswith(suffix):
            return True
    return False
|
||||
|
||||
def matched_files_iter(
|
||||
root_path: str,
|
||||
includes: Iterable = ('*',),
|
||||
ignores: Iterable = (),
|
||||
extensions: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
is_pytorch_extension: bool = False) -> Iterator[str]:
|
||||
def _fnmatch(filepath, patterns):
|
||||
return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
|
||||
|
||||
def match_extensions(filename):
|
||||
"""Helper method to see if filename ends with certain extension"""
|
||||
return any(filename.endswith(e) for e in extensions)
|
||||
|
||||
exact_matches = set(includes)
|
||||
|
||||
# This is a very rough heuristic; really, we want to avoid scanning
|
||||
|
|
@ -141,7 +151,7 @@ def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), ou
|
|||
if (
|
||||
_fnmatch(filepath, includes)
|
||||
and (not _fnmatch(filepath, ignores))
|
||||
and (match_extensions(filepath) or filepath in exact_matches)
|
||||
and (match_extensions(filepath, extensions) or filepath in exact_matches)
|
||||
):
|
||||
if not is_pytorch_extension: # for pytorch extensions, consider all files
|
||||
if not is_pytorch_file(filepath) and not is_caffe2_gpu_file(filepath):
|
||||
|
|
@ -151,14 +161,39 @@ def matched_files_iter(root_path, includes=('*',), ignores=(), extensions=(), ou
|
|||
yield filepath
|
||||
|
||||
|
||||
def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        includes: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    """Hipify a single file with ``preprocessor`` and record the outcome.

    All arguments are forwarded unchanged to ``preprocessor``. Its result (a
    dict with "hipified_path" and "status" entries) is stored in the
    module-level ``HIPIFY_FINAL_RESULT`` mapping, keyed by the absolute,
    normalized path of the input file. Nothing is stored when
    "hipified_path" is None.

    Args:
        output_directory: Directory that ``filepath`` is relative to.
        filepath: Path of the file to hipify, relative to ``output_directory``.
        all_files: Forwarded to ``preprocessor``.
        includes: Forwarded to ``preprocessor``.
        stats: Forwarded to ``preprocessor``.
        hip_clang_launch: Forwarded to ``preprocessor``.
        is_pytorch_extension: Forwarded to ``preprocessor``.
        clean_ctx: Forwarded to ``preprocessor``.
        show_progress: When true, print the input path, the hipified path and
            the status. Also forwarded to ``preprocessor``.
    """
    result = preprocessor(output_directory, filepath, all_files, includes, stats,
                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    # Callers look results up by os.path.abspath(...) of the source/header path.
    # A plain join keeps any ".." components (or a relative output_directory)
    # verbatim, so the stored key would not match those lookups; normalize it.
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))

    # Show what happened
    if show_progress:
        print(
            fin_path, "->",
            result["hipified_path"], result["status"])

    # A None hipified_path means the file produced no hipified output, so
    # there is nothing to map it to.
    if result["hipified_path"] is not None:
        HIPIFY_FINAL_RESULT[fin_path] = result
|
||||
|
||||
|
||||
def preprocess(
|
||||
output_directory,
|
||||
all_files,
|
||||
show_detailed=False,
|
||||
show_progress=True,
|
||||
hip_clang_launch=False,
|
||||
is_pytorch_extension=False,
|
||||
clean_ctx=None):
|
||||
output_directory: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
show_detailed: bool = False,
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
clean_ctx: GeneratedFileCleaner = None) -> HipifyFinalResult:
|
||||
"""
|
||||
Call preprocessor on selected files.
|
||||
|
||||
|
|
@ -173,13 +208,8 @@ def preprocess(
|
|||
stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}
|
||||
|
||||
for filepath in all_files:
|
||||
result = preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx)
|
||||
|
||||
# Show what happened
|
||||
if show_progress:
|
||||
print(
|
||||
filepath, "->",
|
||||
get_hip_file_path(filepath), result)
|
||||
preprocess_file_and_save_result(output_directory, filepath, all_files, includes, stats,
|
||||
hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
|
||||
|
||||
print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
|
||||
|
||||
|
|
@ -187,6 +217,8 @@ def preprocess(
|
|||
if show_detailed:
|
||||
compute_stats(stats)
|
||||
|
||||
return HIPIFY_FINAL_RESULT
|
||||
|
||||
|
||||
def compute_stats(stats):
|
||||
unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}
|
||||
|
|
@ -477,13 +509,13 @@ def replace_extern_shared(input_string):
|
|||
return output_string
|
||||
|
||||
|
||||
def get_hip_file_path(filepath):
|
||||
def get_hip_file_path(filepath, is_pytorch_extension=False):
|
||||
"""
|
||||
Returns the new name of the hipified file
|
||||
"""
|
||||
# At the moment, some files are HIPified in place. The predicate
|
||||
# At the moment, some PyTorch source files are HIPified in place. The predicate
|
||||
# is_out_of_place tells us if this is the case or not.
|
||||
if not is_out_of_place(filepath):
|
||||
if not is_pytorch_extension and not is_out_of_place(filepath):
|
||||
return filepath
|
||||
|
||||
dirpath, filename = os.path.split(filepath)
|
||||
|
|
@ -492,10 +524,8 @@ def get_hip_file_path(filepath):
|
|||
# Here's the plan:
|
||||
#
|
||||
# In general, we need to disambiguate the HIPified filename so that
|
||||
# it gets a different name from the original Caffe2 filename, so
|
||||
# that we don't overwrite the original file. (Additionally,
|
||||
# hcc historically had a bug where if you had two files with
|
||||
# the same basename, they would clobber each other.)
|
||||
# it gets a different name from the original filename, so
|
||||
# that we don't overwrite the original file
|
||||
#
|
||||
# There's a lot of different naming conventions across PyTorch
|
||||
# and Caffe2, but the general recipe is to convert occurrences
|
||||
|
|
@ -509,12 +539,18 @@ def get_hip_file_path(filepath):
|
|||
#
|
||||
# - If the file name contains "CUDA", replace it with "HIP", AND
|
||||
#
|
||||
# If NONE of the above occurred, then insert "hip" in the file path
|
||||
# as the direct parent folder of the file
|
||||
# - ALWAYS replace '.cu' with '.hip', because those files
|
||||
# contain CUDA kernels that need to be hipified and processed with
|
||||
# hip compiler
|
||||
#
|
||||
# Furthermore, ALWAYS replace '.cu' with '.hip', because those files
|
||||
# contain CUDA kernels that needs to be hipified and processed with
|
||||
# hcc compiler
|
||||
# - If we are not hipifying a PyTorch extension, and the parent
|
||||
# directory name did not change as a result of the above
|
||||
# transformations, insert "hip" in the file path
|
||||
# as the direct parent folder of the file
|
||||
#
|
||||
# - If we are hipifying a PyTorch extension, and the parent directory
|
||||
# name as well as the filename (incl. extension) did not change as
|
||||
# a result of the above transformations, insert "_hip" in the filename
|
||||
#
|
||||
# This isn't set in stone; we might adjust this to support other
|
||||
# naming conventions.
|
||||
|
|
@ -522,6 +558,7 @@ def get_hip_file_path(filepath):
|
|||
if ext == '.cu':
|
||||
ext = '.hip'
|
||||
|
||||
orig_filename = filename
|
||||
orig_dirpath = dirpath
|
||||
|
||||
dirpath = dirpath.replace('cuda', 'hip')
|
||||
|
|
@ -533,9 +570,12 @@ def get_hip_file_path(filepath):
|
|||
if dirpath != "caffe2/core":
|
||||
root = root.replace('THC', 'THH')
|
||||
|
||||
if dirpath == orig_dirpath:
|
||||
if not is_pytorch_extension and dirpath == orig_dirpath:
|
||||
dirpath = os.path.join(dirpath, 'hip')
|
||||
|
||||
if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
|
||||
root = root + "_hip"
|
||||
|
||||
return os.path.join(dirpath, root + ext)
|
||||
|
||||
|
||||
|
|
@ -653,13 +693,35 @@ RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
|
|||
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
|
||||
RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
|
||||
|
||||
def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch_extension, clean_ctx):
|
||||
"""
|
||||
Returns a dict with the following keys:
|
||||
"hipified_path" : absolute path of hipified source file
|
||||
"status" : "ok" if hipified file was written out
|
||||
"skipped" if an identical hipified file already existed
|
||||
"ignored" if the source file was a hipified file itself
|
||||
"""
|
||||
def preprocessor(
|
||||
output_directory: str,
|
||||
filepath: str,
|
||||
all_files: Iterable,
|
||||
includes: Iterable,
|
||||
stats: Dict[str, List],
|
||||
hip_clang_launch: bool,
|
||||
is_pytorch_extension: bool,
|
||||
clean_ctx: GeneratedFileCleaner,
|
||||
show_progress: bool) -> HipifyResult:
|
||||
""" Executes the CUDA -> HIP conversion on the specified file. """
|
||||
fin_path = os.path.join(output_directory, filepath)
|
||||
|
||||
with open(fin_path, 'r', encoding='utf-8') as fin:
|
||||
if fin.readline() == HIPIFY_C_BREADCRUMB:
|
||||
return {"hipified_path": None, "status": "ignored"}
|
||||
fin.seek(0)
|
||||
output_source = fin.read()
|
||||
|
||||
fout_path = os.path.join(output_directory, get_hip_file_path(filepath))
|
||||
orig_output_source = output_source
|
||||
|
||||
fout_path = os.path.join(output_directory, get_hip_file_path(filepath, is_pytorch_extension))
|
||||
if not os.path.exists(os.path.dirname(fout_path)):
|
||||
clean_ctx.makedirs(os.path.dirname(fout_path))
|
||||
|
||||
|
|
@ -678,9 +740,10 @@ def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch
|
|||
output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
|
||||
|
||||
# Header rewrites
|
||||
def mk_repl(templ):
|
||||
def mk_repl(templ, include_current_dir=True):
|
||||
def repl(m):
|
||||
f = m.group(1)
|
||||
dirpath, filename = os.path.split(f)
|
||||
if (
|
||||
f.startswith("ATen/cuda")
|
||||
or f.startswith("ATen/native/cuda")
|
||||
|
|
@ -690,11 +753,41 @@ def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch
|
|||
or f.startswith("THCUNN/")
|
||||
or (f.startswith("THC") and not f.startswith("THCP"))
|
||||
):
|
||||
return templ.format(get_hip_file_path(m.group(1)))
|
||||
return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
|
||||
# if filename is one of the files being hipified for this extension
|
||||
if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
|
||||
header_dir = None
|
||||
header_filepath = None
|
||||
# If include_current_dir True, look first in same dir as the including source file
|
||||
if include_current_dir:
|
||||
header_dir_to_check = os.path.dirname(fin_path)
|
||||
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
|
||||
if os.path.exists(header_path_to_check):
|
||||
header_dir = header_dir_to_check
|
||||
header_filepath = header_path_to_check
|
||||
# If not found, look in include dirs one by one and first match wins
|
||||
if header_filepath is None:
|
||||
for include in includes:
|
||||
header_dir_to_check = os.path.join(output_directory, os.path.dirname(include))
|
||||
header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
|
||||
if os.path.exists(header_path_to_check):
|
||||
header_dir = header_dir_to_check
|
||||
header_filepath = header_path_to_check
|
||||
# If header file not found, keep as is
|
||||
if header_filepath is None:
|
||||
return m.group(0)
|
||||
# Hipify header file first if needed
|
||||
if header_filepath not in HIPIFY_FINAL_RESULT:
|
||||
preprocess_file_and_save_result(output_directory,
|
||||
os.path.relpath(header_filepath, output_directory),
|
||||
all_files, includes, stats, hip_clang_launch, is_pytorch_extension,
|
||||
clean_ctx, show_progress)
|
||||
return templ.format(os.path.relpath(HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"], header_dir))
|
||||
|
||||
return m.group(0)
|
||||
return repl
|
||||
output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"'), output_source)
|
||||
output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>'), output_source)
|
||||
output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
|
||||
output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
|
||||
output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)
|
||||
|
||||
# CMakeLists.txt rewrites
|
||||
|
|
@ -717,6 +810,18 @@ def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch
|
|||
# Replace the extern __shared__
|
||||
output_source = replace_extern_shared(output_source)
|
||||
|
||||
# Don't write out identical hipified files for extensions if dirpath has not changed
|
||||
if (
|
||||
is_pytorch_extension
|
||||
and orig_output_source == output_source
|
||||
and os.path.dirname(fin_path) == os.path.dirname(fout_path)
|
||||
):
|
||||
return {"hipified_path": fin_path, "status": "ok"}
|
||||
|
||||
# Add hipify breadcrumb for C-style files to avoid re-hipification
|
||||
if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
|
||||
output_source = HIPIFY_C_BREADCRUMB + output_source
|
||||
|
||||
do_write = True
|
||||
if os.path.exists(fout_path):
|
||||
with open(fout_path, 'r', encoding='utf-8') as fout_old:
|
||||
|
|
@ -724,9 +829,9 @@ def preprocessor(output_directory, filepath, stats, hip_clang_launch, is_pytorch
|
|||
if do_write:
|
||||
with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
|
||||
fout.write(output_source)
|
||||
return "ok"
|
||||
return {"hipified_path": fout_path, "status": "ok"}
|
||||
else:
|
||||
return "skipped"
|
||||
return {"hipified_path": fout_path, "status": "skipped"}
|
||||
|
||||
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
|
||||
with openf(filepath, "r+") as f:
|
||||
|
|
@ -818,19 +923,19 @@ def str2bool(v):
|
|||
|
||||
|
||||
def hipify(
|
||||
project_directory,
|
||||
show_detailed=False,
|
||||
extensions=(".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
|
||||
output_directory="",
|
||||
includes=(),
|
||||
extra_files=(),
|
||||
out_of_place_only=False,
|
||||
ignores=(),
|
||||
show_progress=True,
|
||||
hip_clang_launch=False,
|
||||
is_pytorch_extension=False,
|
||||
clean_ctx=None
|
||||
):
|
||||
project_directory: str,
|
||||
show_detailed: bool = False,
|
||||
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
|
||||
output_directory: str = "",
|
||||
includes: Iterable = (),
|
||||
extra_files: Iterable = (),
|
||||
out_of_place_only: bool = False,
|
||||
ignores: Iterable = (),
|
||||
show_progress: bool = True,
|
||||
hip_clang_launch: bool = False,
|
||||
is_pytorch_extension: bool = False,
|
||||
clean_ctx: GeneratedFileCleaner = None
|
||||
) -> HipifyFinalResult:
|
||||
if project_directory == "":
|
||||
project_directory = os.getcwd()
|
||||
|
||||
|
|
@ -853,12 +958,17 @@ def hipify(
|
|||
out_of_place_only=out_of_place_only,
|
||||
is_pytorch_extension=is_pytorch_extension))
|
||||
all_files_set = set(all_files)
|
||||
all_files += [f for f in extra_files if f not in all_files_set]
|
||||
# Convert extra_files to relative paths since all_files has all relative paths
|
||||
for f in extra_files:
|
||||
f_rel = os.path.relpath(f, output_directory)
|
||||
if f_rel not in all_files_set:
|
||||
all_files.append(f_rel)
|
||||
|
||||
# Start Preprocessor
|
||||
preprocess(
|
||||
return preprocess(
|
||||
output_directory,
|
||||
all_files,
|
||||
includes,
|
||||
show_detailed=show_detailed,
|
||||
show_progress=show_progress,
|
||||
hip_clang_launch=hip_clang_launch,
|
||||
|
|
|
|||
1
torch/utils/hipify/version.py
Normal file
1
torch/utils/hipify/version.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
__version__ = '1.0.0'
|
||||
Loading…
Reference in New Issue
Block a user