pytorch/tools/linter/adapters/gb_registry_linter.py
Sidharth 4d5d56a30e [dynamo] lintrunner for gb_registry adds/updates (#158460)
This PR adds automation to adding/updating the JSON registry through the lintrunner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158460
Approved by: https://github.com/williamwen42
2025-07-23 21:02:54 +00:00

294 lines
9.2 KiB
Python

# mypy: ignore-errors
from __future__ import annotations
import argparse
import json
import sys
from enum import Enum
from pathlib import Path
from typing import Any, NamedTuple
REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(REPO_ROOT))
from tools.dynamo.gb_id_mapping import (
find_unimplemented_v2_calls,
load_registry,
next_gb_id,
)
LINTER_CODE = "GB_REGISTRY"
class LintSeverity(str, Enum):
ERROR = "error"
WARNING = "warning"
ADVICE = "advice"
DISABLED = "disabled"
class LintMessage(NamedTuple):
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: str | None
replacement: str | None
description: str | None
def _collect_all_calls(
dynamo_dir: Path,
) -> dict[str, list[tuple[dict[str, Any], Path]]]:
"""Return mapping *gb_type → list[(call_info, file_path)]* for all occurrences."""
gb_type_calls: dict[str, list[tuple[dict[str, Any], Path]]] = {}
for py_file in dynamo_dir.rglob("*.py"):
for call in find_unimplemented_v2_calls(py_file, dynamo_dir):
gb_type = call["gb_type"]
if gb_type not in gb_type_calls:
gb_type_calls[gb_type] = []
gb_type_calls[gb_type].append((call, py_file))
return gb_type_calls
def _create_registry_entry(
gb_type: str, context: str, explanation: str, hints: list[str]
) -> dict[str, Any]:
"""Create a registry entry with consistent format."""
return {
"Gb_type": gb_type,
"Context": context,
"Explanation": explanation,
"Hints": hints or [],
}
def _update_registry_with_changes(
registry: dict,
calls: dict[str, tuple[dict[str, Any], Path]],
renames: dict[str, str] | None = None,
) -> dict:
"""Calculate what the updated registry should look like."""
renames = renames or {}
updated_registry = dict(registry)
latest_entry: dict[str, Any] = {
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
}
gb_type_to_key: dict[str, str] = {
entries[0]["Gb_type"]: key for key, entries in registry.items()
}
# Method for determining add vs. update:
# - If gb_type exists in registry but content differs: UPDATE (append new entry to preserve history)
# - If gb_type is new but content matches existing entry: RENAME (append new entry with new gb_type)
# - If gb_type is completely new: ADD (create new registry entry with a new GBID)
for old_gb_type, new_gb_type in renames.items():
registry_key = gb_type_to_key[old_gb_type]
old_entry = updated_registry[registry_key][0]
new_entry = _create_registry_entry(
new_gb_type,
old_entry["Context"],
old_entry["Explanation"],
old_entry["Hints"],
)
updated_registry[registry_key] = [new_entry] + updated_registry[registry_key]
latest_entry[new_gb_type] = new_entry
gb_type_to_key[new_gb_type] = registry_key
del latest_entry[old_gb_type]
del gb_type_to_key[old_gb_type]
for gb_type, (call, file_path) in calls.items():
if gb_type in latest_entry:
existing_entry = latest_entry[gb_type]
if not (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
):
registry_key = gb_type_to_key[gb_type]
new_entry = _create_registry_entry(
gb_type, call["context"], call["explanation"], call["hints"]
)
updated_registry[registry_key] = [new_entry] + updated_registry[
registry_key
]
else:
new_key = next_gb_id(updated_registry)
new_entry = _create_registry_entry(
gb_type, call["context"], call["explanation"], call["hints"]
)
updated_registry[new_key] = [new_entry]
return updated_registry
def check_registry_sync(dynamo_dir: Path, registry_path: Path) -> list[LintMessage]:
"""Check registry sync and return lint messages."""
lint_messages = []
all_calls = _collect_all_calls(dynamo_dir)
duplicates = []
for gb_type, call_list in all_calls.items():
if len(call_list) > 1:
first_call = call_list[0][0]
for call, file_path in call_list[1:]:
if (
call["context"] != first_call["context"]
or call["explanation"] != first_call["explanation"]
or sorted(call["hints"]) != sorted(first_call["hints"])
):
duplicates.append({"gb_type": gb_type, "calls": call_list})
break
for dup in duplicates:
gb_type = dup["gb_type"]
calls = dup["calls"]
description = f"The gb_type '{gb_type}' is used {len(calls)} times with different content. "
description += "Each gb_type must be unique across your entire codebase."
lint_messages.append(
LintMessage(
path=str(calls[0][1]),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Duplicate gb_type",
original=None,
replacement=None,
description=description,
)
)
if duplicates:
return lint_messages
calls = {gb_type: calls[0] for gb_type, calls in all_calls.items()}
registry = load_registry(registry_path)
latest_entry: dict[str, Any] = {
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
}
renames: dict[str, str] = {}
remaining_calls = dict(calls)
for gb_type, (call, file_path) in calls.items():
if gb_type not in latest_entry:
for existing_gb_type, existing_entry in latest_entry.items():
if (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
):
renames[existing_gb_type] = gb_type
del remaining_calls[gb_type]
break
needs_update = bool(renames)
for gb_type, (call, file_path) in remaining_calls.items():
if gb_type in latest_entry:
existing_entry = latest_entry[gb_type]
if not (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"] or []) == sorted(existing_entry["Hints"] or [])
):
needs_update = True
break
else:
needs_update = True
break
if needs_update:
updated_registry = _update_registry_with_changes(
registry, remaining_calls, renames
)
original_content = registry_path.read_text(encoding="utf-8")
replacement_content = (
json.dumps(updated_registry, indent=2, ensure_ascii=False) + "\n"
)
changes = []
if renames:
for old, new in renames.items():
changes.append(f"renamed '{old}''{new}'")
if remaining_calls:
new_count = sum(
1 for gb_type in remaining_calls if gb_type not in latest_entry
)
if new_count:
changes.append(f"added {new_count} new gb_types")
description = f"Registry sync needed ({', '.join(changes)}). Run `lintrunner -a` to apply changes."
lint_messages.append(
LintMessage(
path=str(registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=replacement_content,
description=description,
)
)
return lint_messages
if __name__ == "__main__":
script_dir = Path(__file__).resolve()
repo_root = script_dir.parents[3]
default_registry_path = (
repo_root / "torch" / "_dynamo" / "graph_break_registry.json"
)
default_dynamo_dir = repo_root / "torch" / "_dynamo"
parser = argparse.ArgumentParser(
description="Auto-sync graph break registry with source code"
)
parser.add_argument(
"--dynamo-dir",
type=Path,
default=default_dynamo_dir,
help=f"Path to the dynamo directory (default: {default_dynamo_dir})",
)
parser.add_argument(
"--registry-path",
type=Path,
default=default_registry_path,
help=f"Path to the registry file (default: {default_registry_path})",
)
args = parser.parse_args()
lint_messages = check_registry_sync(
dynamo_dir=args.dynamo_dir, registry_path=args.registry_path
)
for lint_message in lint_messages:
print(json.dumps(lint_message._asdict()), flush=True)