[Dynamo][Better Engineering] Typing torch/_dynamo/guards.py (#159315)

As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to `torch/_dynamo/guards.py`

Running
```
mypy torch/_dynamo/guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  2030 | 3945 | 51.46% | 70 | 138 | 50.72% |
| This PR | 4055 | 4055 | 100.00% | 138 | 138 | 100.00% |
| Delta    | +2025 | +90 | +48.54% | +68 | 0 | +49.28% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159315
Approved by: https://github.com/williamwen42, https://github.com/Skylion007
This commit is contained in:
Lucas Kabela 2025-08-06 21:52:14 +00:00 committed by PyTorch MergeBot
parent a5725965ea
commit 40c4d61f9a
6 changed files with 577 additions and 249 deletions

View File

@ -2,12 +2,9 @@ import enum
import types
from typing import Optional, overload
from torch._dynamo.types import (
DynamoCallback,
DynamoGuardCompleteHook,
DynamoGuardHook,
GuardFn,
)
from torch._dynamo.guards import GuardManagerWrapper
from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook
from torch._guards import CompileId
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@ -25,14 +22,20 @@ def raise_sigtrap() -> None: ...
class _CacheEntry:
def check_fn(self, *args: object, **kwargs: object) -> bool: ...
def update_diff_guard_root_manager(self) -> None: ...
code: types.CodeType
compile_id: CompileId
# If we run into circular issues, just use object
guard_manager: GuardManagerWrapper
next: _CacheEntry | None
class _PrecompileEntry:
guard_manager: GuardFn
guard_manager: GuardManagerWrapper
class _ExtraState:
def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...
def invalidate(
self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper
) -> None: ...
class _FrameAction(enum.IntEnum):
DEFAULT = 0
@ -69,7 +72,9 @@ py_opcode_caches: list[int]
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
def _load_precompile_entry(
code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
code: types.CodeType,
guard_manager: GuardManagerWrapper,
dynamo_code: types.CodeType,
) -> None: ...
def _reset_precompile_entries(code: types.CodeType) -> None: ...
def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ...

View File

@ -7,8 +7,15 @@ class GlobalStateGuard:
def check(self) -> bool: ...
def reason(self) -> str: ...
class LeafGuard: ...
class GuardDebugInfo: ...
class LeafGuard:
def verbose_code_parts(self) -> list[str]: ...
class RelationalGuard: ...
class GuardDebugInfo:
verbose_code_parts: list[str]
result: bool
num_guards_executed: int
class GuardManager:
def check(self, value) -> bool: ...
@ -36,6 +43,84 @@ class GuardManager:
example_value,
guard_manager_enum,
) -> GuardManager: ...
def grad_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def generic_getattr_manager(
self,
attr: str,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def getitem_manager(
self,
key,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def get_generic_dict_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def list_getitem_manager(
self,
key,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def tuple_getitem_manager(
self,
key,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def set_getitem_manager(
self,
index,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def func_defaults_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def func_kwdefaults_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def tuple_iterator_getitem_manager(
self,
index,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def weakref_call_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def call_function_no_args_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def global_weakref_manager(
self,
global_name: str,
@ -91,7 +176,44 @@ class GuardManager:
example_value,
guard_manager_enum,
) -> GuardManager: ...
def get_root(self) -> RootGuardManager: ...
def get_source(self) -> str: ...
def fail_count(self) -> int: ...
def get_child_managers(self) -> list[GuardManager]: ...
def repr(self) -> str: ...
def type_of_guarded_value(self) -> str: ...
def get_leaf_guards(self) -> list[LeafGuard]: ...
def get_accessors(self) -> list[GuardManager]: ...
def is_guarded_value_immutable(self) -> bool: ...
def is_tag_safe(self) -> bool: ...
def is_tag_safe_root(self) -> bool: ...
def has_no_accessors(self) -> bool: ...
def has_object_aliasing_guard(self) -> bool: ...
def get_type_of_guarded_value(self) -> type: ...
def type_dict_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def type_mro_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def code_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def closure_manager(
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
# Leaf guards
def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
@ -106,7 +228,94 @@ class GuardManager:
def add_torch_function_mode_stack_guard(
self, initial_stack, verbose_code_parts: list[str]
) -> None: ...
def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ...
def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ...
def add_dict_length_check_guard(
self, value, verbose_code_parts: list[str]
) -> None: ...
def add_length_check_guard(self, value, verbose_code_parts: list[str]) -> None: ...
def add_true_match_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_false_match_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_none_match_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_not_none_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def add_dispatch_key_set_guard(
self,
dispatch_key,
verbose_code_parts: list[str],
) -> None: ...
def add_tensor_match_guard(
self,
value,
sizes,
strides,
tensor_name,
verbose_code_parts: list[str],
ptype,
dispatch_keys,
) -> None: ...
def add_dynamic_indices_guard(
self,
value,
verbose_code_parts: list[str],
) -> None: ...
def add_no_hasattr_guard(
self,
attr_name,
verbose_code_parts: list[str],
) -> None: ...
def add_dict_contains_guard(
self,
contains,
key,
verbose_code_parts: list[str],
) -> None: ...
def add_type_match_guard(
self,
value,
verbose_code_parts: list[str],
) -> None: ...
def add_dict_version_guard(
self,
value,
verbose_code_parts: list[str],
) -> None: ...
def add_set_contains_guard(
self,
contains,
item,
verbose_code_parts: list[str],
) -> None: ...
def add_tuple_iterator_length_guard(
self,
length,
type_id,
verbose_code_parts: list[str],
) -> None: ...
def add_range_iterator_match_guard(
self,
start,
stop,
step,
type_id,
verbose_code_parts: list[str],
) -> None: ...
def add_default_device_guard(
self,
verbose_code_parts: list[str],
) -> None: ...
def mark_tag_safe(self) -> None: ...
def mark_tag_safe_root(self) -> None: ...
class RootGuardManager(GuardManager):
def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
@ -118,6 +327,7 @@ class RootGuardManager(GuardManager):
def clone_manager(
self, clone_filter_fn: Callable[[GuardManager], bool]
) -> RootGuardManager: ...
def attach_compile_id(self, compile_id: str) -> None: ...
class DictGuardManager(GuardManager):
def get_key_manager(
@ -134,6 +344,9 @@ class DictGuardManager(GuardManager):
example_value,
guard_manager_enum,
) -> GuardManager: ...
def get_key_value_managers(
self,
) -> dict[int, tuple[GuardManager, GuardManager]]: ...
# Guard accessor stubs
class GuardAccessor: ...
@ -146,8 +359,8 @@ class GetAttrGuardAccessor(GuardAccessor):
def get_attr_name(self) -> str: ...
def install_object_aliasing_guard(
guard_managers: list[GuardManager],
tensor_names: list[str],
x: GuardManager,
y: GuardManager,
verbose_code_parts: list[str],
): ...
def install_no_tensor_aliasing_guard(

File diff suppressed because it is too large Load Diff

View File

@ -31,7 +31,7 @@ import re
import sys
import traceback
import weakref
from collections.abc import Generator
from collections.abc import Generator, Sequence
from dataclasses import dataclass, field as dc_field
from types import CodeType
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
@ -57,6 +57,7 @@ from torch._guards import (
)
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.export.dynamic_shapes import _ConstraintTarget
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
@ -388,7 +389,7 @@ class OutputGraph(OutputGraphGuardsState):
compiler_fn: Optional[CompilerFn],
root_tx: "InstructionTranslatorBase",
export: bool,
export_constraints: Any,
export_constraints: Sequence[_ConstraintTarget],
frame_state: Any,
local_scope: Scope,
global_scope: Scope,
@ -414,7 +415,7 @@ class OutputGraph(OutputGraphGuardsState):
# de-duplicate graph inputs by source and reuse the tracker
self.input_source_to_var: dict[Source, VariableTracker] = {}
self.export = export
self.export_constraints = export_constraints
self.export_constraints = export_constraints # type: ignore[assignment]
self.frame_state = frame_state
self.cleanup_hooks: list[Callable[[], Any]] = []
# compile_id is an id number for the current torch.compile

View File

@ -206,7 +206,7 @@ def debug_insert_nops(
compiler_fn=None,
root_tx=None, # type: ignore[arg-type]
export=False,
export_constraints=None,
export_constraints=[],
frame_state={"_id": 0},
# TODO: shouldn't this be f_locals/f_globals from frame?
local_scope=locals(),

View File

@ -267,7 +267,7 @@ class Guard:
guard_types: Optional[list[str]] = None
code_list: Optional[list[str]] = None
obj_weakref: Optional[object] = None
guarded_class_weakref: Optional[type] = None
guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None
stack: Optional[CapturedTraceback] = None
user_stack: Optional[traceback.StackSummary] = None
@ -380,7 +380,7 @@ class Guard:
def set_export_info(
self,
guard_type: str,
guarded_class: Optional[type],
guarded_class: Optional[weakref.ReferenceType[Any]],
code_list: list[str],
obj_weakref: object,
) -> None: