mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This PR automates adding and updating entries in the JSON registry through lintrunner. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158460 Approved by: https://github.com/williamwen42
294 lines
9.2 KiB
Python
294 lines
9.2 KiB
Python
# mypy: ignore-errors
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, NamedTuple
|
|
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
|
|
from tools.dynamo.gb_id_mapping import (
|
|
find_unimplemented_v2_calls,
|
|
load_registry,
|
|
next_gb_id,
|
|
)
|
|
|
|
|
|
# Value passed as the ``code`` field of every LintMessage built in this file.
LINTER_CODE = "GB_REGISTRY"
|
|
|
|
|
|
class LintSeverity(str, Enum):
    """Severity level attached to a lint message.

    Members are also ``str`` instances, so they compare equal to their
    plain string value and can be passed straight to ``json.dumps``.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
|
|
|
|
|
|
class LintMessage(NamedTuple):
    """A single lint finding.

    Field order matters: it is the positional order of the tuple and the key
    order produced by ``_asdict()``.
    """

    # Where the finding applies; any part may be None.
    path: str | None
    line: int | None
    char: int | None

    # Which linter produced the finding and how serious it is.
    code: str
    severity: LintSeverity

    # Short title of the finding.
    name: str

    # Current text and proposed text for the file at ``path``; may be None.
    original: str | None
    replacement: str | None

    # Longer human-readable explanation; may be None.
    description: str | None
|
|
|
|
|
|
def _collect_all_calls(
    dynamo_dir: Path,
) -> dict[str, list[tuple[dict[str, Any], Path]]]:
    """Return mapping *gb_type → list[(call_info, file_path)]* for all occurrences.

    Every ``*.py`` file below ``dynamo_dir`` (searched recursively) is handed
    to ``find_unimplemented_v2_calls``; each call it yields is grouped under
    the call's ``"gb_type"`` value together with the file it was found in.

    Args:
        dynamo_dir: Directory to scan.

    Returns:
        A plain dict keyed by gb_type. Files are visited in sorted path
        order, so the order of keys and of each list is reproducible.
    """
    gb_type_calls: dict[str, list[tuple[dict[str, Any], Path]]] = {}

    # Path.rglob yields entries in an unspecified, filesystem-dependent order.
    # Sorting keeps "the first call for a gb_type" stable across machines.
    for py_file in sorted(dynamo_dir.rglob("*.py")):
        for call in find_unimplemented_v2_calls(py_file, dynamo_dir):
            gb_type_calls.setdefault(call["gb_type"], []).append((call, py_file))

    return gb_type_calls
|
|
|
|
|
|
def _create_registry_entry(
    gb_type: str, context: str, explanation: str, hints: list[str]
) -> dict[str, Any]:
    """Build one registry record from the four fields describing a gb_type.

    The keys are always ``Gb_type``, ``Context``, ``Explanation`` and
    ``Hints``, in that order. A falsy ``hints`` value (``None`` or an empty
    list) is stored as a new empty list.
    """
    entry: dict[str, Any] = dict(
        Gb_type=gb_type,
        Context=context,
        Explanation=explanation,
    )
    entry["Hints"] = hints if hints else []
    return entry
|
|
|
|
|
|
def _update_registry_with_changes(
    registry: dict,
    calls: dict[str, tuple[dict[str, Any], Path]],
    renames: dict[str, str] | None = None,
) -> dict:
    """Calculate what the updated registry should look like.

    Args:
        registry: Mapping of registry key to a list of entries. The first
            entry of each list is treated as the current one. The mapping
            and its lists are not modified.
        calls: Mapping of gb_type to ``(call_info, file_path)``. Only
            ``call_info`` (keys ``context``, ``explanation``, ``hints``) is
            read here.
        renames: Optional mapping of old gb_type to new gb_type. Every old
            gb_type must be the current gb_type of some registry key.

    Returns:
        A new dict. Changed keys point at new lists with the new entry
        placed first, so earlier entries are kept as history.
    """
    renames = renames or {}
    updated_registry = dict(registry)

    # Index the current (first) entry of every key by its gb_type.
    latest_entry: dict[str, Any] = {
        entries[0]["Gb_type"]: entries[0] for entries in registry.values()
    }
    gb_type_to_key: dict[str, str] = {
        entries[0]["Gb_type"]: key for key, entries in registry.items()
    }

    def content_matches(call: dict[str, Any], entry: dict[str, Any]) -> bool:
        # Hints are compared without regard to order. A missing (None) hint
        # list counts as empty, which is what _create_registry_entry stores.
        return (
            call["context"] == entry["Context"]
            and call["explanation"] == entry["Explanation"]
            and sorted(call["hints"] or []) == sorted(entry["Hints"] or [])
        )

    # Method for determining add vs. update:
    # - If gb_type exists in registry but content differs: UPDATE (append new entry to preserve history)
    # - If gb_type is new but content matches existing entry: RENAME (append new entry with new gb_type)
    # - If gb_type is completely new: ADD (create new registry entry with a new GBID)

    for old_gb_type, new_gb_type in renames.items():
        registry_key = gb_type_to_key[old_gb_type]
        old_entry = updated_registry[registry_key][0]

        new_entry = _create_registry_entry(
            new_gb_type,
            old_entry["Context"],
            old_entry["Explanation"],
            old_entry["Hints"],
        )
        # Build a new list rather than mutating the caller's list in place.
        updated_registry[registry_key] = [new_entry, *updated_registry[registry_key]]

        # From here on the key is known under its new gb_type only.
        latest_entry[new_gb_type] = new_entry
        gb_type_to_key[new_gb_type] = registry_key
        del latest_entry[old_gb_type]
        del gb_type_to_key[old_gb_type]

    for gb_type, (call, _file_path) in calls.items():
        if gb_type in latest_entry:
            if content_matches(call, latest_entry[gb_type]):
                continue
            # UPDATE: same gb_type, different content.
            registry_key = gb_type_to_key[gb_type]
            new_entry = _create_registry_entry(
                gb_type, call["context"], call["explanation"], call["hints"]
            )
            updated_registry[registry_key] = [new_entry, *updated_registry[registry_key]]
        else:
            # ADD: unknown gb_type gets the next free registry key.
            new_key = next_gb_id(updated_registry)
            new_entry = _create_registry_entry(
                gb_type, call["context"], call["explanation"], call["hints"]
            )
            updated_registry[new_key] = [new_entry]

    return updated_registry
|
|
|
|
|
|
def _gb_call_content(call: dict[str, Any]) -> tuple[Any, Any, list[Any]]:
    """Return a call's comparable content; hints are order-insensitive and None means empty."""
    return (call["context"], call["explanation"], sorted(call["hints"] or []))


def _gb_entry_content(entry: dict[str, Any]) -> tuple[Any, Any, list[Any]]:
    """Return a registry entry's content in the same shape as ``_gb_call_content``."""
    return (entry["Context"], entry["Explanation"], sorted(entry["Hints"] or []))


def check_registry_sync(dynamo_dir: Path, registry_path: Path) -> list[LintMessage]:
    """Check registry sync and return lint messages.

    Args:
        dynamo_dir: Directory scanned for calls via ``_collect_all_calls``.
        registry_path: JSON registry file, read via ``load_registry``.

    Returns:
        * One ERROR message per gb_type that is used at several call sites
          with differing content. In that case nothing else is checked.
        * Otherwise, a single WARNING message carrying the current and the
          proposed registry text when the registry needs a rename, an added
          gb_type or an updated entry.
        * An empty list when everything is in sync.
    """
    lint_messages: list[LintMessage] = []

    all_calls = _collect_all_calls(dynamo_dir)

    # A gb_type may appear at several call sites only if all of them carry
    # identical content.
    for gb_type, call_list in all_calls.items():
        first_content = _gb_call_content(call_list[0][0])
        if all(_gb_call_content(call) == first_content for call, _ in call_list[1:]):
            continue

        description = (
            f"The gb_type '{gb_type}' is used {len(call_list)} times with different content. "
            "Each gb_type must be unique across your entire codebase."
        )
        lint_messages.append(
            LintMessage(
                path=str(call_list[0][1]),
                line=None,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="Duplicate gb_type",
                original=None,
                replacement=None,
                description=description,
            )
        )

    if lint_messages:
        return lint_messages

    # All occurrences of a gb_type agree, so the first one represents it.
    calls = {gb_type: call_list[0] for gb_type, call_list in all_calls.items()}

    registry = load_registry(registry_path)
    latest_entry: dict[str, Any] = {
        entries[0]["Gb_type"]: entries[0] for entries in registry.values()
    }

    # Detect renames: a gb_type unknown to the registry whose content equals
    # the current entry of a gb_type that has disappeared from the code.
    renames: dict[str, str] = {}
    remaining_calls = dict(calls)
    for gb_type, (call, _) in calls.items():
        if gb_type in latest_entry:
            continue
        content = _gb_call_content(call)
        for existing_gb_type, existing_entry in latest_entry.items():
            # An entry whose gb_type is still used in the code has not been
            # renamed, and one entry can be the source of only one rename.
            # Without these guards a second match would overwrite the first
            # and silently drop that gb_type from the registry update.
            if existing_gb_type in calls or existing_gb_type in renames:
                continue
            if content == _gb_entry_content(existing_entry):
                renames[existing_gb_type] = gb_type
                del remaining_calls[gb_type]
                break

    new_gb_types = [
        gb_type for gb_type in remaining_calls if gb_type not in latest_entry
    ]
    updated_gb_types = [
        gb_type
        for gb_type, (call, _) in remaining_calls.items()
        if gb_type in latest_entry
        and _gb_call_content(call) != _gb_entry_content(latest_entry[gb_type])
    ]

    if not (renames or new_gb_types or updated_gb_types):
        return lint_messages

    updated_registry = _update_registry_with_changes(
        registry, remaining_calls, renames
    )

    original_content = registry_path.read_text(encoding="utf-8")
    replacement_content = (
        json.dumps(updated_registry, indent=2, ensure_ascii=False) + "\n"
    )

    changes = [f"renamed '{old}' → '{new}'" for old, new in renames.items()]
    if new_gb_types:
        changes.append(f"added {len(new_gb_types)} new gb_types")
    if updated_gb_types:
        # Listed so the summary is never empty when only content changed.
        changes.append(f"updated {len(updated_gb_types)} existing gb_types")

    description = f"Registry sync needed ({', '.join(changes)}). Run `lintrunner -a` to apply changes."

    lint_messages.append(
        LintMessage(
            path=str(registry_path),
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.WARNING,
            name="Registry sync needed",
            original=original_content,
            replacement=replacement_content,
            description=description,
        )
    )

    return lint_messages
|
|
|
|
|
|
if __name__ == "__main__":
    # The defaults are derived from this file's location: three directory
    # levels above the file's own directory is taken to be the repository root.
    this_file = Path(__file__).resolve()
    checkout_root = this_file.parents[3]
    dynamo_default = checkout_root / "torch" / "_dynamo"
    registry_default = (
        checkout_root / "torch" / "_dynamo" / "graph_break_registry.json"
    )

    cli = argparse.ArgumentParser(
        description="Auto-sync graph break registry with source code"
    )
    cli.add_argument(
        "--dynamo-dir",
        type=Path,
        default=dynamo_default,
        help=f"Path to the dynamo directory (default: {dynamo_default})",
    )
    cli.add_argument(
        "--registry-path",
        type=Path,
        default=registry_default,
        help=f"Path to the registry file (default: {registry_default})",
    )
    options = cli.parse_args()

    lint_messages = check_registry_sync(
        dynamo_dir=options.dynamo_dir, registry_path=options.registry_path
    )

    # One JSON object per line, flushed as soon as it is written.
    for lint_message in lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)
|