From b55b779ad3062b91c64753132264a015378be506 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Fri, 24 Oct 2025 14:43:13 -0700 Subject: [PATCH] Add file size limits to linters and refactor grep_linter (#166202) - Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter - Refactor grep_linter - process files once instead of per-line - Extract allowlist check to separate function - Add 512KB limit for computing replacements, 100 match limit per file - Detect duplicate arguments - Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202 Approved by: https://github.com/Skylion007 --- .lintrunner.toml | 3 +- tools/linter/adapters/codespell_linter.py | 35 +++ tools/linter/adapters/grep_linter.py | 297 ++++++++++++++++------ tools/linter/adapters/newlines_linter.py | 30 +++ 4 files changed, 280 insertions(+), 85 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 023541e4322..26ade791a1b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -833,8 +833,7 @@ exclude_patterns = [ command = [ 'python3', 'tools/linter/adapters/grep_linter.py', - '--pattern=cudaSetDevice(', - '--pattern=cudaGetDevice(', + '--pattern=(cudaSetDevice|cudaGetDevice)\\(', '--linter-name=RAWCUDADEVICE', '--error-name=raw CUDA API usage', """--error-description=\ diff --git a/tools/linter/adapters/codespell_linter.py b/tools/linter/adapters/codespell_linter.py index ce0dd8b6692..8801f623375 100644 --- a/tools/linter/adapters/codespell_linter.py +++ b/tools/linter/adapters/codespell_linter.py @@ -20,6 +20,8 @@ FORBIDDEN_WORDS = { "multipy", # project pytorch/multipy is dead # codespell:ignore multipy } +MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes + class LintSeverity(str, Enum): ERROR = "error" @@ -86,6 +88,39 @@ def run_codespell(path: Path) -> str: def check_file(filename: str) -> list[LintMessage]: path = Path(filename).absolute() + + # Check if file is too large + try: + file_size = os.path.getsize(path) + if file_size > MAX_FILE_SIZE: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="CODESPELL", + severity=LintSeverity.WARNING, + name="file-too-large", + original=None, + replacement=None, + description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping", + ) + ] + except OSError as err: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="CODESPELL", + severity=LintSeverity.ERROR, + name="file-access-error", + original=None, + replacement=None, + description=f"Failed to get file size: {err}", + ) + ] + try: run_codespell(path) except Exception as err: diff --git a/tools/linter/adapters/grep_linter.py b/tools/linter/adapters/grep_linter.py index 62a2d2a95d6..f8c88d02409 100644 --- a/tools/linter/adapters/grep_linter.py +++ b/tools/linter/adapters/grep_linter.py @@ -16,6 +16,11 @@ from typing import NamedTuple IS_WINDOWS: bool = os.name == "nt" +MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes +MAX_MATCHES_PER_FILE: int = 100 # Maximum number of matches to report per file +MAX_ORIGINAL_SIZE: int = ( + 512 * 1024 +) # 512KB - don't compute replacement if original is larger class LintSeverity(str, Enum): @@ -25,6 +30,10 @@ class LintSeverity(str, Enum): DISABLED = "disabled" +LINTER_NAME: str = "" +ERROR_DESCRIPTION: str | None = None + + class LintMessage(NamedTuple): path: str | None line: int | None @@ -56,72 +65,143 @@ def run_command( logging.debug("took %dms", (end_time - start_time) * 1000) +def print_lint_message( + 
name: str, + severity: LintSeverity = LintSeverity.ERROR, + path: str | None = None, + line: int | None = None, + original: str | None = None, + replacement: str | None = None, + description: str | None = None, +) -> None: + """ + Create a LintMessage and print it as JSON. + + Accepts the same arguments as LintMessage constructor. + """ + char = None + code = LINTER_NAME + description = description or ERROR_DESCRIPTION + lint_message = LintMessage( + path, line, char, code, severity, name, original, replacement, description + ) + print(json.dumps(lint_message._asdict()), flush=True) + + +def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]: + """ + Group matching lines by filename. + + Args: + lines: List of grep output lines in format "filename:line:content" + + Returns: + Dictionary mapping filename to list of line remainders (without filename prefix) + """ + grouped: dict[str, list[str]] = {} + for line in lines: + if not line: + continue + # Extract filename and remainder from "filename:line:content" format + parts = line.split(":", 1) + filename = parts[0] + remainder = parts[1] if len(parts) > 1 else "" + if filename not in grouped: + grouped[filename] = [] + grouped[filename].append(remainder) + return grouped + + +def check_allowlist( + filename: str, + allowlist_pattern: str, +) -> bool: + """ + Check if a file matches the allowlist pattern. + + Args: + filename: Path to the file to check + allowlist_pattern: Pattern to grep for in the file + + Returns: + True if the file should be skipped (allowlist pattern matched), False otherwise. + Prints error message and returns False if there was an error running grep. + """ + if not allowlist_pattern: + return False + + try: + proc = run_command(["grep", "-nEHI", allowlist_pattern, filename]) + except Exception as err: + print_lint_message( + name="command-failed", + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + return False + + # allowlist pattern was found, abort lint + if proc.returncode == 0: + return True + + return False + + def lint_file( - matching_line: str, + filename: str, + line_remainders: list[str], allowlist_pattern: str, replace_pattern: str, - linter_name: str, error_name: str, - error_description: str, -) -> LintMessage | None: - # matching_line looks like: - # tools/linter/clangtidy_linter.py:13:import foo.bar.baz - split = matching_line.split(":") - filename = split[0] +) -> None: + """ + Lint a file with one or more pattern matches, printing LintMessages as they're created. 
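# Illustrative sketch (not part of the original patch): how the grep-output
# grouping introduced above is expected to behave, so lint_file runs once per
# file instead of once per matching line. group_lines_by_file is restated
# here in a self-contained form; the sample grep lines are hypothetical.
from __future__ import annotations


def _group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
    """Group "filename:line:content" grep output by filename."""
    grouped: dict[str, list[str]] = {}
    for line in lines:
        if not line:
            continue
        filename, _, remainder = line.partition(":")
        grouped.setdefault(filename, []).append(remainder)
    return grouped


# Two matches in foo.py and one in bar.py collapse into two dictionary
# entries, each holding the "line:content" remainders for that file.
assert _group_lines_by_file(
    [
        "tools/foo.py:13:import foo.bar.baz",
        "tools/foo.py:40:import foo.qux",
        "tools/bar.py:7:import foo.bar.baz",
    ]
) == {
    "tools/foo.py": ["13:import foo.bar.baz", "40:import foo.qux"],
    "tools/bar.py": ["7:import foo.bar.baz"],
}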
- if allowlist_pattern: - try: - proc = run_command(["grep", "-nEHI", allowlist_pattern, filename]) - except Exception as err: - return LintMessage( - path=None, - line=None, - char=None, - code=linter_name, - severity=LintSeverity.ERROR, - name="command-failed", - original=None, - replacement=None, - description=( - f"Failed due to {err.__class__.__name__}:\n{err}" - if not isinstance(err, subprocess.CalledProcessError) - else ( - "COMMAND (exit code {returncode})\n" - "{command}\n\n" - "STDERR\n{stderr}\n\n" - "STDOUT\n{stdout}" - ).format( - returncode=err.returncode, - command=" ".join(as_posix(x) for x in err.cmd), - stderr=err.stderr.decode("utf-8").strip() or "(empty)", - stdout=err.stdout.decode("utf-8").strip() or "(empty)", - ) - ), - ) + Args: + filename: Path to the file being linted + line_remainders: List of line remainders (format: "line:content" without filename prefix) + allowlist_pattern: Pattern to check for allowlisting + replace_pattern: Pattern for sed replacement + error_name: Human-readable error name + """ + if not line_remainders: + return - # allowlist pattern was found, abort lint - if proc.returncode == 0: - return None + should_skip = check_allowlist(filename, allowlist_pattern) + if should_skip: + return + # Check if file is too large to compute replacement + file_size = os.path.getsize(filename) + compute_replacement = replace_pattern and file_size <= MAX_ORIGINAL_SIZE + + # Apply replacement to entire file if pattern is specified and file is not too large original = None replacement = None - if replace_pattern: - with open(filename) as f: - original = f.read() - + if compute_replacement: + # When we have a replacement, report a single message with line=None try: + with open(filename) as f: + original = f.read() + proc = run_command(["sed", "-r", replace_pattern, filename]) replacement = proc.stdout.decode("utf-8") except Exception as err: - return LintMessage( - path=None, - line=None, - char=None, - code=linter_name, - severity=LintSeverity.ERROR, + print_lint_message( name="command-failed", - original=None, - replacement=None, description=( f"Failed due to {err.__class__.__name__}:\n{err}" if not isinstance(err, subprocess.CalledProcessError) @@ -138,18 +218,36 @@ def lint_file( ) ), ) + return - return LintMessage( - path=split[0], - line=int(split[1]) if len(split) > 1 else None, - char=None, - code=linter_name, - severity=LintSeverity.ERROR, - name=error_name, - original=original, - replacement=replacement, - description=error_description, - ) + print_lint_message( + path=filename, + name=error_name, + original=original, + replacement=replacement, + ) + else: + # When no replacement, report each matching line (up to MAX_MATCHES_PER_FILE) + total_matches = len(line_remainders) + matches_to_report = min(total_matches, MAX_MATCHES_PER_FILE) + + for line_remainder in line_remainders[:matches_to_report]: + # line_remainder format: "line_number:content" + split = line_remainder.split(":", 1) + line_number = int(split[0]) if split[0] else None + print_lint_message( + path=filename, + line=line_number, + name=error_name, + ) + + # If there are more matches than the limit, print an error + if total_matches > MAX_MATCHES_PER_FILE: + print_lint_message( + path=filename, + name="too-many-matches", + description=f"File has {total_matches} matches, only showing first {MAX_MATCHES_PER_FILE}", + ) def main() -> None: @@ -203,8 +301,24 @@ def main() -> None: nargs="+", help="paths to lint", ) + + # Check for duplicate arguments before parsing + seen_args = set() + for 
arg in sys.argv[1:]: + if arg.startswith("--"): + arg_name = arg.split("=")[0] + if arg_name in seen_args: + parser.error( + f"argument {arg_name}: not allowed to be specified multiple times" + ) + seen_args.add(arg_name) + args = parser.parse_args() + global LINTER_NAME, ERROR_DESCRIPTION + LINTER_NAME = args.linter_name + ERROR_DESCRIPTION = args.error_description + logging.basicConfig( format="<%(threadName)s:%(levelname)s> %(message)s", level=logging.NOTSET @@ -215,6 +329,31 @@ def main() -> None: stream=sys.stderr, ) + # Filter out files that are too large before running grep + filtered_filenames = [] + for filename in args.filenames: + try: + file_size = os.path.getsize(filename) + if file_size > MAX_FILE_SIZE: + print_lint_message( + path=filename, + severity=LintSeverity.WARNING, + name="file-too-large", + description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping", + ) + else: + filtered_filenames.append(filename) + except OSError as err: + print_lint_message( + path=filename, + name="file-access-error", + description=f"Failed to get file size: {err}", + ) + + # If all files were filtered out, nothing to do + if not filtered_filenames: + return + files_with_matches = [] if args.match_first_only: files_with_matches = ["--files-with-matches"] @@ -223,30 +362,23 @@ def main() -> None: try: # Split the grep command into multiple batches to avoid hitting the # command line length limit of ~1M on my machine - arg_length = sum(len(x) for x in args.filenames) + arg_length = sum(len(x) for x in filtered_filenames) batches = arg_length // 750000 + 1 - batch_size = len(args.filenames) // batches - for i in range(0, len(args.filenames), batch_size): + batch_size = len(filtered_filenames) // batches + for i in range(0, len(filtered_filenames), batch_size): proc = run_command( [ "grep", "-nEHI", *files_with_matches, args.pattern, - *args.filenames[i : i + batch_size], + *filtered_filenames[i : i + batch_size], ] ) lines.extend(proc.stdout.decode().splitlines()) except Exception as err: - err_msg = LintMessage( - path=None, - line=None, - char=None, - code=args.linter_name, - severity=LintSeverity.ERROR, + print_lint_message( name="command-failed", - original=None, - replacement=None, description=( f"Failed due to {err.__class__.__name__}:\n{err}" if not isinstance(err, subprocess.CalledProcessError) @@ -263,20 +395,19 @@ def main() -> None: ) ), ) - print(json.dumps(err_msg._asdict()), flush=True) sys.exit(0) - for line in lines: - lint_message = lint_file( - line, + # Group lines by file to call lint_file once per file + grouped_lines = group_lines_by_file(lines) + + for filename, line_remainders in grouped_lines.items(): + lint_file( + filename, + line_remainders, args.allowlist_pattern, args.replace_pattern, - args.linter_name, args.error_name, - args.error_description, ) - if lint_message is not None: - print(json.dumps(lint_message._asdict()), flush=True) if __name__ == "__main__": diff --git a/tools/linter/adapters/newlines_linter.py b/tools/linter/adapters/newlines_linter.py index 9af1d895699..cbd67c657b2 100644 --- a/tools/linter/adapters/newlines_linter.py +++ b/tools/linter/adapters/newlines_linter.py @@ -7,6 +7,7 @@ from __future__ import annotations import argparse import json import logging +import os import sys from enum import Enum from typing import NamedTuple @@ -15,6 +16,7 @@ from typing import NamedTuple NEWLINE = 10 # ASCII "\n" CARRIAGE_RETURN = 13 # ASCII "\r" LINTER_CODE = "NEWLINE" +MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes 
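# Minimal sketch (not part of the original patch): the duplicate-argument
# guard added to grep_linter.py's main(), restated as a standalone helper so
# it can be exercised directly. The name find_duplicate_flags is
# hypothetical; the logic mirrors the seen_args loop over sys.argv.
from __future__ import annotations


def find_duplicate_flags(argv: list[str]) -> list[str]:
    """Return option names (e.g. '--pattern') that are given more than once."""
    seen: set[str] = set()
    dupes: list[str] = []
    for arg in argv:
        if not arg.startswith("--"):
            continue
        name = arg.split("=", 1)[0]
        if name in seen and name not in dupes:
            dupes.append(name)
        seen.add(name)
    return dupes


# The pre-fix RAWCUDADEVICE config passed --pattern twice, so only one of the
# two patterns actually took effect; the guard now rejects that shape before
# argparse runs.
assert find_duplicate_flags(
    [
        "--pattern=cudaSetDevice(",
        "--pattern=cudaGetDevice(",
        "--linter-name=RAWCUDADEVICE",
    ]
) == ["--pattern"]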
class LintSeverity(str, Enum): @@ -39,6 +41,34 @@ class LintMessage(NamedTuple): def check_file(filename: str) -> LintMessage | None: logging.debug("Checking file %s", filename) + # Check if file is too large + try: + file_size = os.path.getsize(filename) + if file_size > MAX_FILE_SIZE: + return LintMessage( + path=filename, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.WARNING, + name="file-too-large", + original=None, + replacement=None, + description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping", + ) + except OSError as err: + return LintMessage( + path=filename, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="file-access-error", + original=None, + replacement=None, + description=f"Failed to get file size: {err}", + ) + with open(filename, "rb") as f: lines = f.readlines()
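# Minimal sketch (not part of the original patch): the size gate that all
# three linters now apply before reading a file, condensed into a single
# helper. should_skip_large_file is a hypothetical name; the 1GB threshold
# and message wording follow the patch, while folding the OSError case into a
# generic "skip with reason" result is a simplification of the per-linter
# WARNING/ERROR messages.
from __future__ import annotations

import os

MAX_FILE_SIZE: int = 1024 * 1024 * 1024  # 1GB in bytes


def should_skip_large_file(path: str) -> tuple[bool, str | None]:
    """Return (skip, reason); skip is True for unreadable or oversized files."""
    try:
        size = os.path.getsize(path)
    except OSError as err:
        return True, f"Failed to get file size: {err}"
    if size > MAX_FILE_SIZE:
        return True, (
            f"File size ({size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping"
        )
    return False, None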