mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add file size limits to linters and refactor grep_linter (#166202)
- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter - Refactor grep_linter - process files once instead of per-line - Extract allowlist check to separate function - Add 512KB limit for computing replacements, 100 match limit per file - Detect duplicate arguments - Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
74e53d0761
commit
b55b779ad3
|
|
@ -833,8 +833,7 @@ exclude_patterns = [
|
|||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/grep_linter.py',
|
||||
'--pattern=cudaSetDevice(',
|
||||
'--pattern=cudaGetDevice(',
|
||||
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
||||
'--linter-name=RAWCUDADEVICE',
|
||||
'--error-name=raw CUDA API usage',
|
||||
"""--error-description=\
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ FORBIDDEN_WORDS = {
|
|||
"multipy", # project pytorch/multipy is dead # codespell:ignore multipy
|
||||
}
|
||||
|
||||
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
ERROR = "error"
|
||||
|
|
@ -86,6 +88,39 @@ def run_codespell(path: Path) -> str:
|
|||
|
||||
def check_file(filename: str) -> list[LintMessage]:
|
||||
path = Path(filename).absolute()
|
||||
|
||||
# Check if file is too large
|
||||
try:
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return [
|
||||
LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code="CODESPELL",
|
||||
severity=LintSeverity.WARNING,
|
||||
name="file-too-large",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||
)
|
||||
]
|
||||
except OSError as err:
|
||||
return [
|
||||
LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code="CODESPELL",
|
||||
severity=LintSeverity.ERROR,
|
||||
name="file-access-error",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"Failed to get file size: {err}",
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
run_codespell(path)
|
||||
except Exception as err:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,11 @@ from typing import NamedTuple
|
|||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||
MAX_MATCHES_PER_FILE: int = 100 # Maximum number of matches to report per file
|
||||
MAX_ORIGINAL_SIZE: int = (
|
||||
512 * 1024
|
||||
) # 512KB - don't compute replacement if original is larger
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
|
|
@ -25,6 +30,10 @@ class LintSeverity(str, Enum):
|
|||
DISABLED = "disabled"
|
||||
|
||||
|
||||
LINTER_NAME: str = ""
|
||||
ERROR_DESCRIPTION: str | None = None
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: str | None
|
||||
line: int | None
|
||||
|
|
@ -56,32 +65,76 @@ def run_command(
|
|||
logging.debug("took %dms", (end_time - start_time) * 1000)
|
||||
|
||||
|
||||
def lint_file(
|
||||
matching_line: str,
|
||||
allowlist_pattern: str,
|
||||
replace_pattern: str,
|
||||
linter_name: str,
|
||||
error_name: str,
|
||||
error_description: str,
|
||||
) -> LintMessage | None:
|
||||
# matching_line looks like:
|
||||
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
|
||||
split = matching_line.split(":")
|
||||
filename = split[0]
|
||||
def print_lint_message(
|
||||
name: str,
|
||||
severity: LintSeverity = LintSeverity.ERROR,
|
||||
path: str | None = None,
|
||||
line: int | None = None,
|
||||
original: str | None = None,
|
||||
replacement: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Create a LintMessage and print it as JSON.
|
||||
|
||||
Accepts the same arguments as LintMessage constructor.
|
||||
"""
|
||||
char = None
|
||||
code = LINTER_NAME
|
||||
description = description or ERROR_DESCRIPTION
|
||||
lint_message = LintMessage(
|
||||
path, line, char, code, severity, name, original, replacement, description
|
||||
)
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
||||
|
||||
|
||||
def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Group matching lines by filename.
|
||||
|
||||
Args:
|
||||
lines: List of grep output lines in format "filename:line:content"
|
||||
|
||||
Returns:
|
||||
Dictionary mapping filename to list of line remainders (without filename prefix)
|
||||
"""
|
||||
grouped: dict[str, list[str]] = {}
|
||||
for line in lines:
|
||||
if not line:
|
||||
continue
|
||||
# Extract filename and remainder from "filename:line:content" format
|
||||
parts = line.split(":", 1)
|
||||
filename = parts[0]
|
||||
remainder = parts[1] if len(parts) > 1 else ""
|
||||
if filename not in grouped:
|
||||
grouped[filename] = []
|
||||
grouped[filename].append(remainder)
|
||||
return grouped
|
||||
|
||||
|
||||
def check_allowlist(
|
||||
filename: str,
|
||||
allowlist_pattern: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a file matches the allowlist pattern.
|
||||
|
||||
Args:
|
||||
filename: Path to the file to check
|
||||
allowlist_pattern: Pattern to grep for in the file
|
||||
|
||||
Returns:
|
||||
True if the file should be skipped (allowlist pattern matched), False otherwise.
|
||||
Prints error message and returns False if there was an error running grep.
|
||||
"""
|
||||
if not allowlist_pattern:
|
||||
return False
|
||||
|
||||
if allowlist_pattern:
|
||||
try:
|
||||
proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
|
||||
except Exception as err:
|
||||
return LintMessage(
|
||||
path=None,
|
||||
line=None,
|
||||
char=None,
|
||||
code=linter_name,
|
||||
severity=LintSeverity.ERROR,
|
||||
print_lint_message(
|
||||
name="command-failed",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||
if not isinstance(err, subprocess.CalledProcessError)
|
||||
|
|
@ -98,30 +151,57 @@ def lint_file(
|
|||
)
|
||||
),
|
||||
)
|
||||
return False
|
||||
|
||||
# allowlist pattern was found, abort lint
|
||||
if proc.returncode == 0:
|
||||
return None
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def lint_file(
|
||||
filename: str,
|
||||
line_remainders: list[str],
|
||||
allowlist_pattern: str,
|
||||
replace_pattern: str,
|
||||
error_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Lint a file with one or more pattern matches, printing LintMessages as they're created.
|
||||
|
||||
Args:
|
||||
filename: Path to the file being linted
|
||||
line_remainders: List of line remainders (format: "line:content" without filename prefix)
|
||||
allowlist_pattern: Pattern to check for allowlisting
|
||||
replace_pattern: Pattern for sed replacement
|
||||
error_name: Human-readable error name
|
||||
"""
|
||||
if not line_remainders:
|
||||
return
|
||||
|
||||
should_skip = check_allowlist(filename, allowlist_pattern)
|
||||
if should_skip:
|
||||
return
|
||||
|
||||
# Check if file is too large to compute replacement
|
||||
file_size = os.path.getsize(filename)
|
||||
compute_replacement = replace_pattern and file_size <= MAX_ORIGINAL_SIZE
|
||||
|
||||
# Apply replacement to entire file if pattern is specified and file is not too large
|
||||
original = None
|
||||
replacement = None
|
||||
if replace_pattern:
|
||||
if compute_replacement:
|
||||
# When we have a replacement, report a single message with line=None
|
||||
try:
|
||||
with open(filename) as f:
|
||||
original = f.read()
|
||||
|
||||
try:
|
||||
proc = run_command(["sed", "-r", replace_pattern, filename])
|
||||
replacement = proc.stdout.decode("utf-8")
|
||||
except Exception as err:
|
||||
return LintMessage(
|
||||
path=None,
|
||||
line=None,
|
||||
char=None,
|
||||
code=linter_name,
|
||||
severity=LintSeverity.ERROR,
|
||||
print_lint_message(
|
||||
name="command-failed",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||
if not isinstance(err, subprocess.CalledProcessError)
|
||||
|
|
@ -138,17 +218,35 @@ def lint_file(
|
|||
)
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
return LintMessage(
|
||||
path=split[0],
|
||||
line=int(split[1]) if len(split) > 1 else None,
|
||||
char=None,
|
||||
code=linter_name,
|
||||
severity=LintSeverity.ERROR,
|
||||
print_lint_message(
|
||||
path=filename,
|
||||
name=error_name,
|
||||
original=original,
|
||||
replacement=replacement,
|
||||
description=error_description,
|
||||
)
|
||||
else:
|
||||
# When no replacement, report each matching line (up to MAX_MATCHES_PER_FILE)
|
||||
total_matches = len(line_remainders)
|
||||
matches_to_report = min(total_matches, MAX_MATCHES_PER_FILE)
|
||||
|
||||
for line_remainder in line_remainders[:matches_to_report]:
|
||||
# line_remainder format: "line_number:content"
|
||||
split = line_remainder.split(":", 1)
|
||||
line_number = int(split[0]) if split[0] else None
|
||||
print_lint_message(
|
||||
path=filename,
|
||||
line=line_number,
|
||||
name=error_name,
|
||||
)
|
||||
|
||||
# If there are more matches than the limit, print an error
|
||||
if total_matches > MAX_MATCHES_PER_FILE:
|
||||
print_lint_message(
|
||||
path=filename,
|
||||
name="too-many-matches",
|
||||
description=f"File has {total_matches} matches, only showing first {MAX_MATCHES_PER_FILE}",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -203,8 +301,24 @@ def main() -> None:
|
|||
nargs="+",
|
||||
help="paths to lint",
|
||||
)
|
||||
|
||||
# Check for duplicate arguments before parsing
|
||||
seen_args = set()
|
||||
for arg in sys.argv[1:]:
|
||||
if arg.startswith("--"):
|
||||
arg_name = arg.split("=")[0]
|
||||
if arg_name in seen_args:
|
||||
parser.error(
|
||||
f"argument {arg_name}: not allowed to be specified multiple times"
|
||||
)
|
||||
seen_args.add(arg_name)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
global LINTER_NAME, ERROR_DESCRIPTION
|
||||
LINTER_NAME = args.linter_name
|
||||
ERROR_DESCRIPTION = args.error_description
|
||||
|
||||
logging.basicConfig(
|
||||
format="<%(threadName)s:%(levelname)s> %(message)s",
|
||||
level=logging.NOTSET
|
||||
|
|
@ -215,6 +329,31 @@ def main() -> None:
|
|||
stream=sys.stderr,
|
||||
)
|
||||
|
||||
# Filter out files that are too large before running grep
|
||||
filtered_filenames = []
|
||||
for filename in args.filenames:
|
||||
try:
|
||||
file_size = os.path.getsize(filename)
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
print_lint_message(
|
||||
path=filename,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="file-too-large",
|
||||
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||
)
|
||||
else:
|
||||
filtered_filenames.append(filename)
|
||||
except OSError as err:
|
||||
print_lint_message(
|
||||
path=filename,
|
||||
name="file-access-error",
|
||||
description=f"Failed to get file size: {err}",
|
||||
)
|
||||
|
||||
# If all files were filtered out, nothing to do
|
||||
if not filtered_filenames:
|
||||
return
|
||||
|
||||
files_with_matches = []
|
||||
if args.match_first_only:
|
||||
files_with_matches = ["--files-with-matches"]
|
||||
|
|
@ -223,30 +362,23 @@ def main() -> None:
|
|||
try:
|
||||
# Split the grep command into multiple batches to avoid hitting the
|
||||
# command line length limit of ~1M on my machine
|
||||
arg_length = sum(len(x) for x in args.filenames)
|
||||
arg_length = sum(len(x) for x in filtered_filenames)
|
||||
batches = arg_length // 750000 + 1
|
||||
batch_size = len(args.filenames) // batches
|
||||
for i in range(0, len(args.filenames), batch_size):
|
||||
batch_size = len(filtered_filenames) // batches
|
||||
for i in range(0, len(filtered_filenames), batch_size):
|
||||
proc = run_command(
|
||||
[
|
||||
"grep",
|
||||
"-nEHI",
|
||||
*files_with_matches,
|
||||
args.pattern,
|
||||
*args.filenames[i : i + batch_size],
|
||||
*filtered_filenames[i : i + batch_size],
|
||||
]
|
||||
)
|
||||
lines.extend(proc.stdout.decode().splitlines())
|
||||
except Exception as err:
|
||||
err_msg = LintMessage(
|
||||
path=None,
|
||||
line=None,
|
||||
char=None,
|
||||
code=args.linter_name,
|
||||
severity=LintSeverity.ERROR,
|
||||
print_lint_message(
|
||||
name="command-failed",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||
if not isinstance(err, subprocess.CalledProcessError)
|
||||
|
|
@ -263,20 +395,19 @@ def main() -> None:
|
|||
)
|
||||
),
|
||||
)
|
||||
print(json.dumps(err_msg._asdict()), flush=True)
|
||||
sys.exit(0)
|
||||
|
||||
for line in lines:
|
||||
lint_message = lint_file(
|
||||
line,
|
||||
# Group lines by file to call lint_file once per file
|
||||
grouped_lines = group_lines_by_file(lines)
|
||||
|
||||
for filename, line_remainders in grouped_lines.items():
|
||||
lint_file(
|
||||
filename,
|
||||
line_remainders,
|
||||
args.allowlist_pattern,
|
||||
args.replace_pattern,
|
||||
args.linter_name,
|
||||
args.error_name,
|
||||
args.error_description,
|
||||
)
|
||||
if lint_message is not None:
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from __future__ import annotations
|
|||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import NamedTuple
|
||||
|
|
@ -15,6 +16,7 @@ from typing import NamedTuple
|
|||
NEWLINE = 10 # ASCII "\n"
|
||||
CARRIAGE_RETURN = 13 # ASCII "\r"
|
||||
LINTER_CODE = "NEWLINE"
|
||||
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
|
|
@ -39,6 +41,34 @@ class LintMessage(NamedTuple):
|
|||
def check_file(filename: str) -> LintMessage | None:
|
||||
logging.debug("Checking file %s", filename)
|
||||
|
||||
# Check if file is too large
|
||||
try:
|
||||
file_size = os.path.getsize(filename)
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="file-too-large",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||
)
|
||||
except OSError as err:
|
||||
return LintMessage(
|
||||
path=filename,
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="file-access-error",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=f"Failed to get file size: {err}",
|
||||
)
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user