mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add file size limits to linters and refactor grep_linter (#166202)
- Add 1GB file size limits to grep_linter, newlines_linter, codespell_linter
- Refactor grep_linter
  - process files once instead of per-line
  - Extract allowlist check to separate function
  - Add 512KB limit for computing replacements, 100 match limit per file
  - Detect duplicate arguments
- Fix .lintrunner.toml: RAWCUDADEVICE used --pattern twice

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166202
Approved by: https://github.com/Skylion007
This commit is contained in:
parent
74e53d0761
commit
b55b779ad3
|
|
@ -833,8 +833,7 @@ exclude_patterns = [
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/grep_linter.py',
|
'tools/linter/adapters/grep_linter.py',
|
||||||
'--pattern=cudaSetDevice(',
|
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
||||||
'--pattern=cudaGetDevice(',
|
|
||||||
'--linter-name=RAWCUDADEVICE',
|
'--linter-name=RAWCUDADEVICE',
|
||||||
'--error-name=raw CUDA API usage',
|
'--error-name=raw CUDA API usage',
|
||||||
"""--error-description=\
|
"""--error-description=\
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@ FORBIDDEN_WORDS = {
|
||||||
"multipy", # project pytorch/multipy is dead # codespell:ignore multipy
|
"multipy", # project pytorch/multipy is dead # codespell:ignore multipy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||||
|
|
||||||
|
|
||||||
class LintSeverity(str, Enum):
|
class LintSeverity(str, Enum):
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
@ -86,6 +88,39 @@ def run_codespell(path: Path) -> str:
|
||||||
|
|
||||||
def check_file(filename: str) -> list[LintMessage]:
|
def check_file(filename: str) -> list[LintMessage]:
|
||||||
path = Path(filename).absolute()
|
path = Path(filename).absolute()
|
||||||
|
|
||||||
|
# Check if file is too large
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(path)
|
||||||
|
if file_size > MAX_FILE_SIZE:
|
||||||
|
return [
|
||||||
|
LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code="CODESPELL",
|
||||||
|
severity=LintSeverity.WARNING,
|
||||||
|
name="file-too-large",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
except OSError as err:
|
||||||
|
return [
|
||||||
|
LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code="CODESPELL",
|
||||||
|
severity=LintSeverity.ERROR,
|
||||||
|
name="file-access-error",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description=f"Failed to get file size: {err}",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
run_codespell(path)
|
run_codespell(path)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,11 @@ from typing import NamedTuple
|
||||||
|
|
||||||
|
|
||||||
IS_WINDOWS: bool = os.name == "nt"
|
IS_WINDOWS: bool = os.name == "nt"
|
||||||
|
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||||
|
MAX_MATCHES_PER_FILE: int = 100 # Maximum number of matches to report per file
|
||||||
|
MAX_ORIGINAL_SIZE: int = (
|
||||||
|
512 * 1024
|
||||||
|
) # 512KB - don't compute replacement if original is larger
|
||||||
|
|
||||||
|
|
||||||
class LintSeverity(str, Enum):
|
class LintSeverity(str, Enum):
|
||||||
|
|
@ -25,6 +30,10 @@ class LintSeverity(str, Enum):
|
||||||
DISABLED = "disabled"
|
DISABLED = "disabled"
|
||||||
|
|
||||||
|
|
||||||
|
LINTER_NAME: str = ""
|
||||||
|
ERROR_DESCRIPTION: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LintMessage(NamedTuple):
|
class LintMessage(NamedTuple):
|
||||||
path: str | None
|
path: str | None
|
||||||
line: int | None
|
line: int | None
|
||||||
|
|
@ -56,32 +65,76 @@ def run_command(
|
||||||
logging.debug("took %dms", (end_time - start_time) * 1000)
|
logging.debug("took %dms", (end_time - start_time) * 1000)
|
||||||
|
|
||||||
|
|
||||||
def lint_file(
|
def print_lint_message(
|
||||||
matching_line: str,
|
name: str,
|
||||||
allowlist_pattern: str,
|
severity: LintSeverity = LintSeverity.ERROR,
|
||||||
replace_pattern: str,
|
path: str | None = None,
|
||||||
linter_name: str,
|
line: int | None = None,
|
||||||
error_name: str,
|
original: str | None = None,
|
||||||
error_description: str,
|
replacement: str | None = None,
|
||||||
) -> LintMessage | None:
|
description: str | None = None,
|
||||||
# matching_line looks like:
|
) -> None:
|
||||||
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
|
"""
|
||||||
split = matching_line.split(":")
|
Create a LintMessage and print it as JSON.
|
||||||
filename = split[0]
|
|
||||||
|
Accepts the same arguments as LintMessage constructor.
|
||||||
|
"""
|
||||||
|
char = None
|
||||||
|
code = LINTER_NAME
|
||||||
|
description = description or ERROR_DESCRIPTION
|
||||||
|
lint_message = LintMessage(
|
||||||
|
path, line, char, code, severity, name, original, replacement, description
|
||||||
|
)
|
||||||
|
print(json.dumps(lint_message._asdict()), flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def group_lines_by_file(lines: list[str]) -> dict[str, list[str]]:
    """
    Group matching lines by filename.

    Each non-empty input line is split at its first ":"; the text before it
    is used as the filename key and the text after it is kept as the
    remainder. A line that contains no ":" at all is stored under the whole
    line as the key, with an empty-string remainder. Empty input lines are
    skipped. Files appear in the result in the order they are first seen,
    and remainders keep their input order.

    Args:
        lines: List of grep output lines in format "filename:line:content"

    Returns:
        Dictionary mapping filename to list of line remainders (without filename prefix)
    """
    # NOTE(review): the split is on the FIRST colon, so a filename that itself
    # contains ":" (e.g. a Windows drive prefix) would be truncated -- confirm
    # that callers only pass paths without colons.
    grouped: dict[str, list[str]] = {}
    for line in lines:
        if not line:
            continue
        # partition() yields an empty remainder when there is no ":".
        filename, _, remainder = line.partition(":")
        grouped.setdefault(filename, []).append(remainder)
    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def check_allowlist(
|
||||||
|
filename: str,
|
||||||
|
allowlist_pattern: str,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a file matches the allowlist pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to the file to check
|
||||||
|
allowlist_pattern: Pattern to grep for in the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the file should be skipped (allowlist pattern matched), False otherwise.
|
||||||
|
Prints error message and returns False if there was an error running grep.
|
||||||
|
"""
|
||||||
|
if not allowlist_pattern:
|
||||||
|
return False
|
||||||
|
|
||||||
if allowlist_pattern:
|
|
||||||
try:
|
try:
|
||||||
proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
|
proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return LintMessage(
|
print_lint_message(
|
||||||
path=None,
|
|
||||||
line=None,
|
|
||||||
char=None,
|
|
||||||
code=linter_name,
|
|
||||||
severity=LintSeverity.ERROR,
|
|
||||||
name="command-failed",
|
name="command-failed",
|
||||||
original=None,
|
|
||||||
replacement=None,
|
|
||||||
description=(
|
description=(
|
||||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||||
if not isinstance(err, subprocess.CalledProcessError)
|
if not isinstance(err, subprocess.CalledProcessError)
|
||||||
|
|
@ -98,30 +151,57 @@ def lint_file(
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# allowlist pattern was found, abort lint
|
# allowlist pattern was found, abort lint
|
||||||
if proc.returncode == 0:
|
if proc.returncode == 0:
|
||||||
return None
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def lint_file(
|
||||||
|
filename: str,
|
||||||
|
line_remainders: list[str],
|
||||||
|
allowlist_pattern: str,
|
||||||
|
replace_pattern: str,
|
||||||
|
error_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Lint a file with one or more pattern matches, printing LintMessages as they're created.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Path to the file being linted
|
||||||
|
line_remainders: List of line remainders (format: "line:content" without filename prefix)
|
||||||
|
allowlist_pattern: Pattern to check for allowlisting
|
||||||
|
replace_pattern: Pattern for sed replacement
|
||||||
|
error_name: Human-readable error name
|
||||||
|
"""
|
||||||
|
if not line_remainders:
|
||||||
|
return
|
||||||
|
|
||||||
|
should_skip = check_allowlist(filename, allowlist_pattern)
|
||||||
|
if should_skip:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if file is too large to compute replacement
|
||||||
|
file_size = os.path.getsize(filename)
|
||||||
|
compute_replacement = replace_pattern and file_size <= MAX_ORIGINAL_SIZE
|
||||||
|
|
||||||
|
# Apply replacement to entire file if pattern is specified and file is not too large
|
||||||
original = None
|
original = None
|
||||||
replacement = None
|
replacement = None
|
||||||
if replace_pattern:
|
if compute_replacement:
|
||||||
|
# When we have a replacement, report a single message with line=None
|
||||||
|
try:
|
||||||
with open(filename) as f:
|
with open(filename) as f:
|
||||||
original = f.read()
|
original = f.read()
|
||||||
|
|
||||||
try:
|
|
||||||
proc = run_command(["sed", "-r", replace_pattern, filename])
|
proc = run_command(["sed", "-r", replace_pattern, filename])
|
||||||
replacement = proc.stdout.decode("utf-8")
|
replacement = proc.stdout.decode("utf-8")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return LintMessage(
|
print_lint_message(
|
||||||
path=None,
|
|
||||||
line=None,
|
|
||||||
char=None,
|
|
||||||
code=linter_name,
|
|
||||||
severity=LintSeverity.ERROR,
|
|
||||||
name="command-failed",
|
name="command-failed",
|
||||||
original=None,
|
|
||||||
replacement=None,
|
|
||||||
description=(
|
description=(
|
||||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||||
if not isinstance(err, subprocess.CalledProcessError)
|
if not isinstance(err, subprocess.CalledProcessError)
|
||||||
|
|
@ -138,17 +218,35 @@ def lint_file(
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
return LintMessage(
|
print_lint_message(
|
||||||
path=split[0],
|
path=filename,
|
||||||
line=int(split[1]) if len(split) > 1 else None,
|
|
||||||
char=None,
|
|
||||||
code=linter_name,
|
|
||||||
severity=LintSeverity.ERROR,
|
|
||||||
name=error_name,
|
name=error_name,
|
||||||
original=original,
|
original=original,
|
||||||
replacement=replacement,
|
replacement=replacement,
|
||||||
description=error_description,
|
)
|
||||||
|
else:
|
||||||
|
# When no replacement, report each matching line (up to MAX_MATCHES_PER_FILE)
|
||||||
|
total_matches = len(line_remainders)
|
||||||
|
matches_to_report = min(total_matches, MAX_MATCHES_PER_FILE)
|
||||||
|
|
||||||
|
for line_remainder in line_remainders[:matches_to_report]:
|
||||||
|
# line_remainder format: "line_number:content"
|
||||||
|
split = line_remainder.split(":", 1)
|
||||||
|
line_number = int(split[0]) if split[0] else None
|
||||||
|
print_lint_message(
|
||||||
|
path=filename,
|
||||||
|
line=line_number,
|
||||||
|
name=error_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there are more matches than the limit, print an error
|
||||||
|
if total_matches > MAX_MATCHES_PER_FILE:
|
||||||
|
print_lint_message(
|
||||||
|
path=filename,
|
||||||
|
name="too-many-matches",
|
||||||
|
description=f"File has {total_matches} matches, only showing first {MAX_MATCHES_PER_FILE}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -203,8 +301,24 @@ def main() -> None:
|
||||||
nargs="+",
|
nargs="+",
|
||||||
help="paths to lint",
|
help="paths to lint",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check for duplicate arguments before parsing
|
||||||
|
seen_args = set()
|
||||||
|
for arg in sys.argv[1:]:
|
||||||
|
if arg.startswith("--"):
|
||||||
|
arg_name = arg.split("=")[0]
|
||||||
|
if arg_name in seen_args:
|
||||||
|
parser.error(
|
||||||
|
f"argument {arg_name}: not allowed to be specified multiple times"
|
||||||
|
)
|
||||||
|
seen_args.add(arg_name)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
global LINTER_NAME, ERROR_DESCRIPTION
|
||||||
|
LINTER_NAME = args.linter_name
|
||||||
|
ERROR_DESCRIPTION = args.error_description
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="<%(threadName)s:%(levelname)s> %(message)s",
|
format="<%(threadName)s:%(levelname)s> %(message)s",
|
||||||
level=logging.NOTSET
|
level=logging.NOTSET
|
||||||
|
|
@ -215,6 +329,31 @@ def main() -> None:
|
||||||
stream=sys.stderr,
|
stream=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out files that are too large before running grep
|
||||||
|
filtered_filenames = []
|
||||||
|
for filename in args.filenames:
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(filename)
|
||||||
|
if file_size > MAX_FILE_SIZE:
|
||||||
|
print_lint_message(
|
||||||
|
path=filename,
|
||||||
|
severity=LintSeverity.WARNING,
|
||||||
|
name="file-too-large",
|
||||||
|
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filtered_filenames.append(filename)
|
||||||
|
except OSError as err:
|
||||||
|
print_lint_message(
|
||||||
|
path=filename,
|
||||||
|
name="file-access-error",
|
||||||
|
description=f"Failed to get file size: {err}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# If all files were filtered out, nothing to do
|
||||||
|
if not filtered_filenames:
|
||||||
|
return
|
||||||
|
|
||||||
files_with_matches = []
|
files_with_matches = []
|
||||||
if args.match_first_only:
|
if args.match_first_only:
|
||||||
files_with_matches = ["--files-with-matches"]
|
files_with_matches = ["--files-with-matches"]
|
||||||
|
|
@ -223,30 +362,23 @@ def main() -> None:
|
||||||
try:
|
try:
|
||||||
# Split the grep command into multiple batches to avoid hitting the
|
# Split the grep command into multiple batches to avoid hitting the
|
||||||
# command line length limit of ~1M on my machine
|
# command line length limit of ~1M on my machine
|
||||||
arg_length = sum(len(x) for x in args.filenames)
|
arg_length = sum(len(x) for x in filtered_filenames)
|
||||||
batches = arg_length // 750000 + 1
|
batches = arg_length // 750000 + 1
|
||||||
batch_size = len(args.filenames) // batches
|
batch_size = len(filtered_filenames) // batches
|
||||||
for i in range(0, len(args.filenames), batch_size):
|
for i in range(0, len(filtered_filenames), batch_size):
|
||||||
proc = run_command(
|
proc = run_command(
|
||||||
[
|
[
|
||||||
"grep",
|
"grep",
|
||||||
"-nEHI",
|
"-nEHI",
|
||||||
*files_with_matches,
|
*files_with_matches,
|
||||||
args.pattern,
|
args.pattern,
|
||||||
*args.filenames[i : i + batch_size],
|
*filtered_filenames[i : i + batch_size],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
lines.extend(proc.stdout.decode().splitlines())
|
lines.extend(proc.stdout.decode().splitlines())
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
err_msg = LintMessage(
|
print_lint_message(
|
||||||
path=None,
|
|
||||||
line=None,
|
|
||||||
char=None,
|
|
||||||
code=args.linter_name,
|
|
||||||
severity=LintSeverity.ERROR,
|
|
||||||
name="command-failed",
|
name="command-failed",
|
||||||
original=None,
|
|
||||||
replacement=None,
|
|
||||||
description=(
|
description=(
|
||||||
f"Failed due to {err.__class__.__name__}:\n{err}"
|
f"Failed due to {err.__class__.__name__}:\n{err}"
|
||||||
if not isinstance(err, subprocess.CalledProcessError)
|
if not isinstance(err, subprocess.CalledProcessError)
|
||||||
|
|
@ -263,20 +395,19 @@ def main() -> None:
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
print(json.dumps(err_msg._asdict()), flush=True)
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
for line in lines:
|
# Group lines by file to call lint_file once per file
|
||||||
lint_message = lint_file(
|
grouped_lines = group_lines_by_file(lines)
|
||||||
line,
|
|
||||||
|
for filename, line_remainders in grouped_lines.items():
|
||||||
|
lint_file(
|
||||||
|
filename,
|
||||||
|
line_remainders,
|
||||||
args.allowlist_pattern,
|
args.allowlist_pattern,
|
||||||
args.replace_pattern,
|
args.replace_pattern,
|
||||||
args.linter_name,
|
|
||||||
args.error_name,
|
args.error_name,
|
||||||
args.error_description,
|
|
||||||
)
|
)
|
||||||
if lint_message is not None:
|
|
||||||
print(json.dumps(lint_message._asdict()), flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
@ -15,6 +16,7 @@ from typing import NamedTuple
|
||||||
NEWLINE = 10 # ASCII "\n"
|
NEWLINE = 10 # ASCII "\n"
|
||||||
CARRIAGE_RETURN = 13 # ASCII "\r"
|
CARRIAGE_RETURN = 13 # ASCII "\r"
|
||||||
LINTER_CODE = "NEWLINE"
|
LINTER_CODE = "NEWLINE"
|
||||||
|
MAX_FILE_SIZE: int = 1024 * 1024 * 1024 # 1GB in bytes
|
||||||
|
|
||||||
|
|
||||||
class LintSeverity(str, Enum):
|
class LintSeverity(str, Enum):
|
||||||
|
|
@ -39,6 +41,34 @@ class LintMessage(NamedTuple):
|
||||||
def check_file(filename: str) -> LintMessage | None:
|
def check_file(filename: str) -> LintMessage | None:
|
||||||
logging.debug("Checking file %s", filename)
|
logging.debug("Checking file %s", filename)
|
||||||
|
|
||||||
|
# Check if file is too large
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(filename)
|
||||||
|
if file_size > MAX_FILE_SIZE:
|
||||||
|
return LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code=LINTER_CODE,
|
||||||
|
severity=LintSeverity.WARNING,
|
||||||
|
name="file-too-large",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description=f"File size ({file_size} bytes) exceeds {MAX_FILE_SIZE} bytes limit, skipping",
|
||||||
|
)
|
||||||
|
except OSError as err:
|
||||||
|
return LintMessage(
|
||||||
|
path=filename,
|
||||||
|
line=None,
|
||||||
|
char=None,
|
||||||
|
code=LINTER_CODE,
|
||||||
|
severity=LintSeverity.ERROR,
|
||||||
|
name="file-access-error",
|
||||||
|
original=None,
|
||||||
|
replacement=None,
|
||||||
|
description=f"Failed to get file size: {err}",
|
||||||
|
)
|
||||||
|
|
||||||
with open(filename, "rb") as f:
|
with open(filename, "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user