Add file size limits to linters and refactor grep_linter (#166202)

- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter
- Refactor grep_linter
  - process files once instead of per-line
  - Extract allowlist check to separate function
  - Add 512KB limit for computing replacements, 100 match limit per file
  - Detect duplicate arguments
- Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice (see the sketch after this list)
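
For context on the last two bullets: argparse's default "store" action silently keeps only the last occurrence of a repeated option, so the duplicated --pattern meant only the cudaGetDevice pattern ever took effect. A minimal sketch of that behavior (not the linter's real parser):

    import argparse

    # A plain "store" option keeps only the last value when passed twice,
    # which is how the duplicated --pattern silently dropped the
    # cudaSetDevice check.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pattern")
    args = parser.parse_args(["--pattern=cudaSetDevice(", "--pattern=cudaGetDevice("])
    print(args.pattern)  # prints "cudaGetDevice(" -- the first pattern is gone
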
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202
Approved by: https://github.com/Skylion007
Aaron Orenstein authored on 2025-10-24 14:43:13 -07:00; committed by PyTorch MergeBot
parent 74e53d0761
commit b55b779ad3
4 changed files with 280 additions and 85 deletions

View File: .lintrunner.toml

@@ -833,8 +833,7 @@ exclude_patterns = [
 command = [
     'python3',
     'tools/linter/adapters/grep_linter.py',
-    '--pattern=cudaSetDevice(',
-    '--pattern=cudaGetDevice(',
+    '--pattern=(cudaSetDevice|cudaGetDevice)\\(',
     '--linter-name=RAWCUDADEVICE',
     '--error-name=raw CUDA API usage',
     """--error-description=\
View File: codespell_linter

@@ -20,6 +20,8 @@ FORBIDDEN_WORDS = {
     "multipy",  # project pytorch/multipy is dead  # codespell:ignore multipy
 }
 
+MAX_FILE_SIZE: int = 1024 * 1024 * 1024  # 1GB in bytes
+
 
 class LintSeverity(str, Enum):
     ERROR = "error"
@@ -86,6 +88,39 @@ def run_codespell(path: Path) -> str:
 def check_file(filename: str) -> list[LintMessage]:
     path = Path(filename).absolute()
 
+    # Check if file is too large
+    try:
+        file_size = os.path.getsize(path)
+        if file_size > MAX_FILE_SIZE:
+            return [
+                LintMessage(
+                    path=filename,
+                    line=None,
+                    char=None,
+                    code="CODESPELL",
+                    severity=LintSeverity.WARNING,
+                    name="file-too-large",
+                    original=None,
+                    replacement=None,
+                    description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
+                )
+            ]
+    except OSError as err:
+        return [
+            LintMessage(
+                path=filename,
+                line=None,
+                char=None,
+                code="CODESPELL",
+                severity=LintSeverity.ERROR,
+                name="file-access-error",
+                original=None,
+                replacement=None,
+                description=f"Failed to get file size: {err}",
+            )
+        ]
+
     try:
         run_codespell(path)
     except Exception as err:
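
The same 1GB guard is repeated in grep_linter and newlines_linter below. Note the severity split: an oversized file is skipped and reported as a WARNING ("file-too-large"), while a file whose size cannot be read at all is an ERROR ("file-access-error"). A minimal sketch of that decision, using a hypothetical helper name that is not part of this PR:

    import os

    MAX_FILE_SIZE = 1024 * 1024 * 1024  # 1GB, mirroring the constant each linter defines

    # Hypothetical helper (not in the PR): classify a path the way the linters do.
    def classify_path(path: str) -> str:
        try:
            size = os.path.getsize(path)
        except OSError:
            return "file-access-error"  # reported at ERROR severity
        if size > MAX_FILE_SIZE:
            return "file-too-large"  # reported at WARNING severity, file is skipped
        return "ok"  # lint the file normally
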

View File: tools/linter/adapters/grep_linter.py

@@ -16,6 +16,11 @@ from typing import NamedTuple
 
 IS_WINDOWS: bool = os.name == "nt"
 
+MAX_FILE_SIZE: int = 1024 * 1024 * 1024  # 1GB in bytes
+MAX_MATCHES_PER_FILE: int = 100  # Maximum number of matches to report per file
+MAX_ORIGINAL_SIZE: int = (
+    512 * 1024
+)  # 512KB - don't compute replacement if original is larger
 
 class LintSeverity(str, Enum):
@@ -25,6 +30,10 @@ class LintSeverity(str, Enum):
     DISABLED = "disabled"
 
+LINTER_NAME: str = ""
+ERROR_DESCRIPTION: str | None = None
+
 
 class LintMessage(NamedTuple):
     path: str | None
     line: int | None
@@ -56,32 +65,76 @@ def run_command(
     logging.debug("took %dms", (end_time - start_time) * 1000)
 
 
-def lint_file(
-    matching_line: str,
-    allowlist_pattern: str,
-    replace_pattern: str,
-    linter_name: str,
-    error_name: str,
-    error_description: str,
-) -> LintMessage | None:
-    # matching_line looks like:
-    # tools/linter/clangtidy_linter.py:13:import foo.bar.baz
-    split = matching_line.split(":")
-    filename = split[0]
-
-    if allowlist_pattern:
-        try:
-            proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
-        except Exception as err:
-            return LintMessage(
-                path=None,
-                line=None,
-                char=None,
-                code=linter_name,
-                severity=LintSeverity.ERROR,
-                name="command-failed",
-                original=None,
-                replacement=None,
-                description=(
-                    f"Failed due to {err.__class__.__name__}:\n{err}"
-                    if not isinstance(err, subprocess.CalledProcessError)
+def print_lint_message(
+    name: str,
+    severity: LintSeverity = LintSeverity.ERROR,
+    path: str | None = None,
+    line: int | None = None,
+    original: str | None = None,
+    replacement: str | None = None,
+    description: str | None = None,
+) -> None:
+    """
+    Create a LintMessage and print it as JSON.
+
+    Accepts the same arguments as LintMessage constructor.
+    """
+    char = None
+    code = LINTER_NAME
+    description = description or ERROR_DESCRIPTION
+    lint_message = LintMessage(
+        path, line, char, code, severity, name, original, replacement, description
+    )
+    print(json.dumps(lint_message._asdict()), flush=True)
+
+
+def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
+    """
+    Group matching lines by filename.
+
+    Args:
+        lines: List of grep output lines in format "filename:line:content"
+
+    Returns:
+        Dictionary mapping filename to list of line remainders (without filename prefix)
+    """
+    grouped: dict[str, list[str]] = {}
+    for line in lines:
+        if not line:
+            continue
+        # Extract filename and remainder from "filename:line:content" format
+        parts = line.split(":", 1)
+        filename = parts[0]
+        remainder = parts[1] if len(parts) > 1 else ""
+        if filename not in grouped:
+            grouped[filename] = []
+        grouped[filename].append(remainder)
+    return grouped
+
+
+def check_allowlist(
+    filename: str,
+    allowlist_pattern: str,
+) -> bool:
+    """
+    Check if a file matches the allowlist pattern.
+
+    Args:
+        filename: Path to the file to check
+        allowlist_pattern: Pattern to grep for in the file
+
+    Returns:
+        True if the file should be skipped (allowlist pattern matched), False otherwise.
+        Prints error message and returns False if there was an error running grep.
+    """
+    if not allowlist_pattern:
+        return False
+
+    try:
+        proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
+    except Exception as err:
+        print_lint_message(
+            name="command-failed",
+            description=(
+                f"Failed due to {err.__class__.__name__}:\n{err}"
+                if not isinstance(err, subprocess.CalledProcessError)
@@ -98,30 +151,57 @@ def lint_file(
                 )
             ),
         )
+        return False
 
-        # allowlist pattern was found, abort lint
-        if proc.returncode == 0:
-            return None
+    # allowlist pattern was found, abort lint
+    if proc.returncode == 0:
+        return True
+    return False
+
+
+def lint_file(
+    filename: str,
+    line_remainders: list[str],
+    allowlist_pattern: str,
+    replace_pattern: str,
+    error_name: str,
+) -> None:
+    """
+    Lint a file with one or more pattern matches, printing LintMessages as they're created.
+
+    Args:
+        filename: Path to the file being linted
+        line_remainders: List of line remainders (format: "line:content" without filename prefix)
+        allowlist_pattern: Pattern to check for allowlisting
+        replace_pattern: Pattern for sed replacement
+        error_name: Human-readable error name
+    """
+    if not line_remainders:
+        return
+
+    should_skip = check_allowlist(filename, allowlist_pattern)
+    if should_skip:
+        return
+
+    # Check if file is too large to compute replacement
+    file_size = os.path.getsize(filename)
+    compute_replacement = replace_pattern and file_size <= MAX_ORIGINAL_SIZE
+
+    # Apply replacement to entire file if pattern is specified and file is not too large
     original = None
     replacement = None
-    if replace_pattern:
-        with open(filename) as f:
-            original = f.read()
-
-        try:
+    if compute_replacement:
+        # When we have a replacement, report a single message with line=None
+        try:
+            with open(filename) as f:
+                original = f.read()
             proc = run_command(["sed", "-r", replace_pattern, filename])
             replacement = proc.stdout.decode("utf-8")
         except Exception as err:
-            return LintMessage(
-                path=None,
-                line=None,
-                char=None,
-                code=linter_name,
-                severity=LintSeverity.ERROR,
+            print_lint_message(
                 name="command-failed",
-                original=None,
-                replacement=None,
                 description=(
                     f"Failed due to {err.__class__.__name__}:\n{err}"
                     if not isinstance(err, subprocess.CalledProcessError)
@@ -138,17 +218,35 @@ def lint_file(
                     )
                 ),
             )
+            return
 
-    return LintMessage(
-        path=split[0],
-        line=int(split[1]) if len(split) > 1 else None,
-        char=None,
-        code=linter_name,
-        severity=LintSeverity.ERROR,
-        name=error_name,
-        original=original,
-        replacement=replacement,
-        description=error_description,
-    )
+        print_lint_message(
+            path=filename,
+            name=error_name,
+            original=original,
+            replacement=replacement,
+        )
+    else:
+        # When no replacement, report each matching line (up to MAX_MATCHES_PER_FILE)
+        total_matches = len(line_remainders)
+        matches_to_report = min(total_matches, MAX_MATCHES_PER_FILE)
+        for line_remainder in line_remainders[:matches_to_report]:
+            # line_remainder format: "line_number:content"
+            split = line_remainder.split(":", 1)
+            line_number = int(split[0]) if split[0] else None
+            print_lint_message(
+                path=filename,
+                line=line_number,
+                name=error_name,
+            )
+        # If there are more matches than the limit, print an error
+        if total_matches > MAX_MATCHES_PER_FILE:
+            print_lint_message(
+                path=filename,
+                name="too-many-matches",
+                description=f"File has {total_matches} matches, only showing first {MAX_MATCHES_PER_FILE}",
+            )
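
Every message ends up as one JSON object per line on stdout (print(json.dumps(lint_message._asdict()), flush=True)), which is the interface lintrunner consumes. A rough sketch of what a single per-line match serializes to; the field names come from the LintMessage NamedTuple above, and the concrete values are made up, borrowing the RAWCUDADEVICE config from .lintrunner.toml:

    from __future__ import annotations

    import json
    from typing import NamedTuple

    class LintMessage(NamedTuple):
        path: str | None
        line: int | None
        char: int | None
        code: str
        severity: str
        name: str
        original: str | None
        replacement: str | None
        description: str | None

    msg = LintMessage(
        path="tools/linter/clangtidy_linter.py",  # made-up path (from the old comment)
        line=13,
        char=None,
        code="RAWCUDADEVICE",       # LINTER_NAME, set from --linter-name
        severity="error",
        name="raw CUDA API usage",  # from --error-name
        original=None,
        replacement=None,
        description="raw CUDA API usage is discouraged",  # stand-in for --error-description
    )
    print(json.dumps(msg._asdict()), flush=True)
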
@@ -203,8 +301,24 @@ def main() -> None:
         nargs="+",
         help="paths to lint",
     )
+
+    # Check for duplicate arguments before parsing
+    seen_args = set()
+    for arg in sys.argv[1:]:
+        if arg.startswith("--"):
+            arg_name = arg.split("=")[0]
+            if arg_name in seen_args:
+                parser.error(
+                    f"argument {arg_name}: not allowed to be specified multiple times"
+                )
+            seen_args.add(arg_name)
+
     args = parser.parse_args()
 
+    global LINTER_NAME, ERROR_DESCRIPTION
+    LINTER_NAME = args.linter_name
+    ERROR_DESCRIPTION = args.error_description
+
     logging.basicConfig(
         format="<%(threadName)s:%(levelname)s> %(message)s",
         level=logging.NOTSET
@@ -215,6 +329,31 @@ def main() -> None:
         stream=sys.stderr,
     )
 
+    # Filter out files that are too large before running grep
+    filtered_filenames = []
+    for filename in args.filenames:
+        try:
+            file_size = os.path.getsize(filename)
+            if file_size > MAX_FILE_SIZE:
+                print_lint_message(
+                    path=filename,
+                    severity=LintSeverity.WARNING,
+                    name="file-too-large",
+                    description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
+                )
+            else:
+                filtered_filenames.append(filename)
+        except OSError as err:
+            print_lint_message(
+                path=filename,
+                name="file-access-error",
+                description=f"Failed to get file size: {err}",
+            )
+
+    # If all files were filtered out, nothing to do
+    if not filtered_filenames:
+        return
+
     files_with_matches = []
     if args.match_first_only:
         files_with_matches = ["--files-with-matches"]
@@ -223,30 +362,23 @@ def main() -> None:
     try:
         # Split the grep command into multiple batches to avoid hitting the
         # command line length limit of ~1M on my machine
-        arg_length = sum(len(x) for x in args.filenames)
+        arg_length = sum(len(x) for x in filtered_filenames)
         batches = arg_length // 750000 + 1
-        batch_size = len(args.filenames) // batches
-        for i in range(0, len(args.filenames), batch_size):
+        batch_size = len(filtered_filenames) // batches
+        for i in range(0, len(filtered_filenames), batch_size):
             proc = run_command(
                 [
                     "grep",
                     "-nEHI",
                     *files_with_matches,
                     args.pattern,
-                    *args.filenames[i : i + batch_size],
+                    *filtered_filenames[i : i + batch_size],
                 ]
             )
             lines.extend(proc.stdout.decode().splitlines())
     except Exception as err:
-        err_msg = LintMessage(
-            path=None,
-            line=None,
-            char=None,
-            code=args.linter_name,
-            severity=LintSeverity.ERROR,
+        print_lint_message(
             name="command-failed",
-            original=None,
-            replacement=None,
             description=(
                 f"Failed due to {err.__class__.__name__}:\n{err}"
                 if not isinstance(err, subprocess.CalledProcessError)
@@ -263,20 +395,19 @@ def main() -> None:
                 )
             ),
         )
-        print(json.dumps(err_msg._asdict()), flush=True)
         sys.exit(0)
 
-    for line in lines:
-        lint_message = lint_file(
-            line,
+    # Group lines by file to call lint_file once per file
+    grouped_lines = group_lines_by_file(lines)
+
+    for filename, line_remainders in grouped_lines.items():
+        lint_file(
+            filename,
+            line_remainders,
             args.allowlist_pattern,
             args.replace_pattern,
-            args.linter_name,
             args.error_name,
-            args.error_description,
         )
-        if lint_message is not None:
-            print(json.dumps(lint_message._asdict()), flush=True)
 
 
 if __name__ == "__main__":
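
For reference, a small self-contained sketch of the grouping step that drives the new one-call-per-file flow. It mirrors group_lines_by_file above (condensed with dict.setdefault); the grep output lines are invented:

    from __future__ import annotations

    # grep output is grouped by filename so lint_file runs once per file
    # instead of once per matching line.
    def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
        grouped: dict[str, list[str]] = {}
        for line in lines:
            if not line:
                continue
            filename, _, remainder = line.partition(":")
            grouped.setdefault(filename, []).append(remainder)
        return grouped

    grep_output = [
        "a.cpp:10:cudaSetDevice(0);",
        "a.cpp:42:cudaGetDevice(&dev);",
        "b.cpp:7:cudaSetDevice(1);",
    ]
    print(group_lines_by_file(grep_output))
    # {'a.cpp': ['10:cudaSetDevice(0);', '42:cudaGetDevice(&dev);'], 'b.cpp': ['7:cudaSetDevice(1);']}
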

View File: newlines_linter

@@ -7,6 +7,7 @@ from __future__ import annotations
 import argparse
 import json
 import logging
+import os
 import sys
 from enum import Enum
 from typing import NamedTuple
@@ -15,6 +16,7 @@ from typing import NamedTuple
 NEWLINE = 10  # ASCII "\n"
 CARRIAGE_RETURN = 13  # ASCII "\r"
 LINTER_CODE = "NEWLINE"
+MAX_FILE_SIZE: int = 1024 * 1024 * 1024  # 1GB in bytes
 
 
 class LintSeverity(str, Enum):
@@ -39,6 +41,34 @@ class LintMessage(NamedTuple):
 def check_file(filename: str) -> LintMessage | None:
     logging.debug("Checking file %s", filename)
 
+    # Check if file is too large
+    try:
+        file_size = os.path.getsize(filename)
+        if file_size > MAX_FILE_SIZE:
+            return LintMessage(
+                path=filename,
+                line=None,
+                char=None,
+                code=LINTER_CODE,
+                severity=LintSeverity.WARNING,
+                name="file-too-large",
+                original=None,
+                replacement=None,
+                description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
+            )
+    except OSError as err:
+        return LintMessage(
+            path=filename,
+            line=None,
+            char=None,
+            code=LINTER_CODE,
+            severity=LintSeverity.ERROR,
+            name="file-access-error",
+            original=None,
+            replacement=None,
+            description=f"Failed to get file size: {err}",
+        )
+
     with open(filename, "rb") as f:
         lines = f.readlines()