import os
import re
import sys

# Regular expression identifies a kernel launch indicator by
# finding something approximating the pattern ">>>(arguments);"
# It then requires that `TORCH_CUDA_KERNEL_LAUNCH_CHECK` be
# the next command.
# It allows a single backslash `\` between the end of the launch
# command and the beginning of the kernel check. This handles
# cases where the kernel launch is in a multiline preprocessor
# definition.
#
# The regex is run on code whose lines have been prefixed with "N: "
# (see `check_code_for_cuda_kernel_launches`), so wherever a newline
# may legally occur the pattern must also tolerate the line-number
# prefix that follows it.
#
# There are various ways this can fail:
#  * If the semicolon is in a string for some reason
#  * If there's a triply-nested template
# But this should be sufficient to detect and fix most problem
# instances and can be refined before the test is made binding
kernel_launch_regex = re.compile(r"""
    >>>                   # Identifies kernel launch
    \s*                   # Maybe some whitespace (includes newlines)
    (?:[0-9]+:\s*)*       # Line numberings, if the arguments start on a later line
    \([^;]+\);            # And then arguments in parens and semi-colon
    (?!                   # Negative lookahead: we trigger if we don't find the launch guard
        \s*               # Maybe some whitespace (includes newlines)
        \\?               # 0 or 1 backslashes (for launches in preprocessor macros)
        \s*               # The newline that follows a continuation backslash
        (?:[0-9]+:\s*)*   # Line numberings (one per intervening line), if present.
                          # NB: under re.VERBOSE a literal space in the pattern is
                          # ignored, so the separator after the colon is matched
                          # with an explicit whitespace class instead.
        TORCH_CUDA_KERNEL_LAUNCH_CHECK\(\);  # Kernel launch guard!
    )                     # End negative lookahead
""", flags=re.MULTILINE | re.VERBOSE)


def check_code_for_cuda_kernel_launches(code, filename=None):
    """Checks code for CUDA kernel launches without cuda error checks.

    Every match is reported on stderr together with the matched text,
    which carries the (0-based) line-number prefixes added below.

    Args:
        code     - The code to check
        filename - Filename of file containing the code. Used only for display
                   purposes, so you can put anything here.

    Returns:
        The number of unsafe kernel launches in the code
    """
    if filename is None:
        filename = "##Python Function Call##"

    # We break the code apart and put it back together to add
    # helpful line numberings for identifying problem areas
    numbered_lines = [
        f"{lineno}: {linecode}"
        for lineno, linecode in enumerate(code.split("\n"))
    ]
    numbered_code = "\n".join(numbered_lines)

    # The pattern has no capturing groups, so findall yields the full
    # text of each launch that lacks a guard.
    results = kernel_launch_regex.findall(numbered_code)
    for r in results:
        print(f"Missing TORCH_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{r}", file=sys.stderr)
    return len(results)


def check_file(filename):
    """Checks a file for CUDA kernel launches without cuda error checks

    Files whose names do not end in `.cu` or `.cuh` are not opened.

    Args:
        filename - File to check

    Returns:
        The number of unsafe kernel launches in the file
    """
    if not filename.endswith((".cu", ".cuh")):
        return 0
    # Use a context manager so the handle is closed promptly; this
    # function is called for every file found while walking the tree.
    with open(filename, "r") as f:
        contents = f.read()
    return check_code_for_cuda_kernel_launches(contents, filename)


def check_cuda_kernel_launches():
    """Checks all pytorch code for CUDA kernel launches without cuda error checks

    The search root is two directory levels above the directory that
    contains this file. A summary is printed to stderr when any
    unchecked launch is found.

    Returns:
        The number of unsafe kernel launches in the codebase
    """
    torch_dir = os.path.dirname(os.path.realpath(__file__))
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2

    # `$BASE/build` and `$BASE/torch/include` are generated
    # so we don't want to flag their contents.
    # Paths are built component-wise so the comparison with `root`
    # also holds where the path separator is not "/".
    generated_dirs = (
        os.path.join(torch_dir, "build"),
        os.path.join(torch_dir, "torch", "include"),
    )

    kernels_without_checks = 0
    files_without_checks = []
    for root, dirnames, filenames in os.walk(torch_dir):
        if root in generated_dirs:
            # Curtail search by modifying dirnames and filenames in place
            # Yes, this is the way to do this, see `help(os.walk)`
            dirnames[:] = []
            continue

        for x in filenames:
            filename = os.path.join(root, x)
            file_result = check_file(filename)
            if file_result > 0:
                kernels_without_checks += file_result
                files_without_checks.append(filename)

    if kernels_without_checks > 0:
        count_str = f"Found {kernels_without_checks} instances in " \
                    f"{len(files_without_checks)} files where kernel " \
                    "launches didn't have checks."
        # The summary is printed both before and after the file list.
        print(count_str, file=sys.stderr)
        print("Files without checks:", file=sys.stderr)
        for x in files_without_checks:
            print(f"\t{x}", file=sys.stderr)
        print(count_str, file=sys.stderr)

    return kernels_without_checks


if __name__ == "__main__":
    # The check is not binding yet (see the note above the regex), so the
    # findings are only reported on stderr and the exit status stays 0.
    check_cuda_kernel_launches()
    sys.exit(0)