Add file size limits to linters and refactor grep_linter (#166202)

- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter
- Refactor grep_linter
  - process files once instead of per-line
  - Extract allowlist check to separate function
  - Add 512KB limit for computing replacements, 100 match limit per file
  - Detect duplicate arguments
- Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2025-10-24 14:43:13 -07:00 committed by PyTorch MergeBot
parent 74e53d0761
commit b55b779ad3
4 changed files with 280 additions and 85 deletions

View File

@ -833,8 +833,7 @@ exclude_patterns = [
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
'--pattern=cudaSetDevice(',
'--pattern=cudaGetDevice(',
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
'--linter-name=RAWCUDADEVICE',
'--error-name=raw CUDA API usage',
"""--error-description=\

View File

@ -20,6 +20,8 @@ FORBIDDEN_WORDS = {
"multipy", # project pytorch/multipy is dead # codespell:ignore multipy
}
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
class LintSeverity(str, Enum):
ERROR = "error"
@ -86,6 +88,39 @@ def run_codespell(path: Path) -> str:
def check_file(filename: str) -> list[LintMessage]:
path = Path(filename).absolute()
# Check if file is too large
try:
file_size = os.path.getsize(path)
if file_size > MAX_FILE_SIZE:
return [
LintMessage(
path=filename,
line=None,
char=None,
code="CODESPELL",
severity=LintSeverity.WARNING,
name="file-too-large",
original=None,
replacement=None,
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
)
]
except OSError as err:
return [
LintMessage(
path=filename,
line=None,
char=None,
code="CODESPELL",
severity=LintSeverity.ERROR,
name="file-access-error",
original=None,
replacement=None,
description=f"Failed to get file size: {err}",
)
]
try:
run_codespell(path)
except Exception as err:

View File

@ -16,6 +16,11 @@ from typing import NamedTuple
IS_WINDOWS: bool = os.name == "nt"
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
MAX_MATCHES_PER_FILE: int = 100 # Maximum number of matches to report per file
MAX_ORIGINAL_SIZE: int = (
512 * 1024
) # 512KB - don't compute replacement if original is larger
class LintSeverity(str, Enum):
@ -25,6 +30,10 @@ class LintSeverity(str, Enum):
DISABLED = "disabled"
LINTER_NAME: str = ""
ERROR_DESCRIPTION: str | None = None
class LintMessage(NamedTuple):
path: str | None
line: int | None
@ -56,32 +65,76 @@ def run_command(
logging.debug("took %dms", (end_time - start_time) * 1000)
def lint_file(
matching_line: str,
allowlist_pattern: str,
replace_pattern: str,
linter_name: str,
error_name: str,
error_description: str,
) -> LintMessage | None:
# matching_line looks like:
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
split = matching_line.split(":")
filename = split[0]
def print_lint_message(
name: str,
severity: LintSeverity = LintSeverity.ERROR,
path: str | None = None,
line: int | None = None,
original: str | None = None,
replacement: str | None = None,
description: str | None = None,
) -> None:
"""
Create a LintMessage and print it as JSON.
Accepts the same arguments as LintMessage constructor.
"""
char = None
code = LINTER_NAME
description = description or ERROR_DESCRIPTION
lint_message = LintMessage(
path, line, char, code, severity, name, original, replacement, description
)
print(json.dumps(lint_message._asdict()), flush=True)
def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
"""
Group matching lines by filename.
Args:
lines: List of grep output lines in format "filename:line:content"
Returns:
Dictionary mapping filename to list of line remainders (without filename prefix)
"""
grouped: dict[str, list[str]] = {}
for line in lines:
if not line:
continue
# Extract filename and remainder from "filename:line:content" format
parts = line.split(":", 1)
filename = parts[0]
remainder = parts[1] if len(parts) > 1 else ""
if filename not in grouped:
grouped[filename] = []
grouped[filename].append(remainder)
return grouped
def check_allowlist(
filename: str,
allowlist_pattern: str,
) -> bool:
"""
Check if a file matches the allowlist pattern.
Args:
filename: Path to the file to check
allowlist_pattern: Pattern to grep for in the file
Returns:
True if the file should be skipped (allowlist pattern matched), False otherwise.
Prints error message and returns False if there was an error running grep.
"""
if not allowlist_pattern:
return False
if allowlist_pattern:
try:
proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
except Exception as err:
return LintMessage(
path=None,
line=None,
char=None,
code=linter_name,
severity=LintSeverity.ERROR,
print_lint_message(
name="command-failed",
original=None,
replacement=None,
description=(
f"Failed due to {err.__class__.__name__}:\n{err}"
if not isinstance(err, subprocess.CalledProcessError)
@ -98,30 +151,57 @@ def lint_file(
)
),
)
return False
# allowlist pattern was found, abort lint
if proc.returncode == 0:
return None
return True
return False
def lint_file(
filename: str,
line_remainders: list[str],
allowlist_pattern: str,
replace_pattern: str,
error_name: str,
) -> None:
"""
Lint a file with one or more pattern matches, printing LintMessages as they're created.
Args:
filename: Path to the file being linted
line_remainders: List of line remainders (format: "line:content" without filename prefix)
allowlist_pattern: Pattern to check for allowlisting
replace_pattern: Pattern for sed replacement
error_name: Human-readable error name
"""
if not line_remainders:
return
should_skip = check_allowlist(filename, allowlist_pattern)
if should_skip:
return
# Check if file is too large to compute replacement
file_size = os.path.getsize(filename)
compute_replacement = replace_pattern and file_size <= MAX_ORIGINAL_SIZE
# Apply replacement to entire file if pattern is specified and file is not too large
original = None
replacement = None
if replace_pattern:
if compute_replacement:
# When we have a replacement, report a single message with line=None
try:
with open(filename) as f:
original = f.read()
try:
proc = run_command(["sed", "-r", replace_pattern, filename])
replacement = proc.stdout.decode("utf-8")
except Exception as err:
return LintMessage(
path=None,
line=None,
char=None,
code=linter_name,
severity=LintSeverity.ERROR,
print_lint_message(
name="command-failed",
original=None,
replacement=None,
description=(
f"Failed due to {err.__class__.__name__}:\n{err}"
if not isinstance(err, subprocess.CalledProcessError)
@ -138,17 +218,35 @@ def lint_file(
)
),
)
return
return LintMessage(
path=split[0],
line=int(split[1]) if len(split) > 1 else None,
char=None,
code=linter_name,
severity=LintSeverity.ERROR,
print_lint_message(
path=filename,
name=error_name,
original=original,
replacement=replacement,
description=error_description,
)
else:
# When no replacement, report each matching line (up to MAX_MATCHES_PER_FILE)
total_matches = len(line_remainders)
matches_to_report = min(total_matches, MAX_MATCHES_PER_FILE)
for line_remainder in line_remainders[:matches_to_report]:
# line_remainder format: "line_number:content"
split = line_remainder.split(":", 1)
line_number = int(split[0]) if split[0] else None
print_lint_message(
path=filename,
line=line_number,
name=error_name,
)
# If there are more matches than the limit, print an error
if total_matches > MAX_MATCHES_PER_FILE:
print_lint_message(
path=filename,
name="too-many-matches",
description=f"File has {total_matches} matches, only showing first {MAX_MATCHES_PER_FILE}",
)
@ -203,8 +301,24 @@ def main() -> None:
nargs="+",
help="paths to lint",
)
# Check for duplicate arguments before parsing
seen_args = set()
for arg in sys.argv[1:]:
if arg.startswith("--"):
arg_name = arg.split("=")[0]
if arg_name in seen_args:
parser.error(
f"argument {arg_name}: not allowed to be specified multiple times"
)
seen_args.add(arg_name)
args = parser.parse_args()
global LINTER_NAME, ERROR_DESCRIPTION
LINTER_NAME = args.linter_name
ERROR_DESCRIPTION = args.error_description
logging.basicConfig(
format="<%(threadName)s:%(levelname)s> %(message)s",
level=logging.NOTSET
@ -215,6 +329,31 @@ def main() -> None:
stream=sys.stderr,
)
# Filter out files that are too large before running grep
filtered_filenames = []
for filename in args.filenames:
try:
file_size = os.path.getsize(filename)
if file_size > MAX_FILE_SIZE:
print_lint_message(
path=filename,
severity=LintSeverity.WARNING,
name="file-too-large",
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
)
else:
filtered_filenames.append(filename)
except OSError as err:
print_lint_message(
path=filename,
name="file-access-error",
description=f"Failed to get file size: {err}",
)
# If all files were filtered out, nothing to do
if not filtered_filenames:
return
files_with_matches = []
if args.match_first_only:
files_with_matches = ["--files-with-matches"]
@ -223,30 +362,23 @@ def main() -> None:
try:
# Split the grep command into multiple batches to avoid hitting the
# command line length limit of ~1M on my machine
arg_length = sum(len(x) for x in args.filenames)
arg_length = sum(len(x) for x in filtered_filenames)
batches = arg_length // 750000 + 1
batch_size = len(args.filenames) // batches
for i in range(0, len(args.filenames), batch_size):
batch_size = len(filtered_filenames) // batches
for i in range(0, len(filtered_filenames), batch_size):
proc = run_command(
[
"grep",
"-nEHI",
*files_with_matches,
args.pattern,
*args.filenames[i : i + batch_size],
*filtered_filenames[i : i + batch_size],
]
)
lines.extend(proc.stdout.decode().splitlines())
except Exception as err:
err_msg = LintMessage(
path=None,
line=None,
char=None,
code=args.linter_name,
severity=LintSeverity.ERROR,
print_lint_message(
name="command-failed",
original=None,
replacement=None,
description=(
f"Failed due to {err.__class__.__name__}:\n{err}"
if not isinstance(err, subprocess.CalledProcessError)
@ -263,20 +395,19 @@ def main() -> None:
)
),
)
print(json.dumps(err_msg._asdict()), flush=True)
sys.exit(0)
for line in lines:
lint_message = lint_file(
line,
# Group lines by file to call lint_file once per file
grouped_lines = group_lines_by_file(lines)
for filename, line_remainders in grouped_lines.items():
lint_file(
filename,
line_remainders,
args.allowlist_pattern,
args.replace_pattern,
args.linter_name,
args.error_name,
args.error_description,
)
if lint_message is not None:
print(json.dumps(lint_message._asdict()), flush=True)
if __name__ == "__main__":

View File

@ -7,6 +7,7 @@ from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from enum import Enum
from typing import NamedTuple
@ -15,6 +16,7 @@ from typing import NamedTuple
NEWLINE = 10 # ASCII "\n"
CARRIAGE_RETURN = 13 # ASCII "\r"
LINTER_CODE = "NEWLINE"
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
class LintSeverity(str, Enum):
@ -39,6 +41,34 @@ class LintMessage(NamedTuple):
def check_file(filename: str) -> LintMessage | None:
logging.debug("Checking file %s", filename)
# Check if file is too large
try:
file_size = os.path.getsize(filename)
if file_size > MAX_FILE_SIZE:
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="file-too-large",
original=None,
replacement=None,
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
)
except OSError as err:
return LintMessage(
path=filename,
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="file-access-error",
original=None,
replacement=None,
description=f"Failed to get file size: {err}",
)
with open(filename, "rb") as f:
lines = f.readlines()