mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo][Better Engineering] Typing torch/_dynamo/guards.py (#159315)
As part of better engineering effort, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to `torch/_dynamo/guards.py` Running ``` mypy torch/_dynamo/guards.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 2030 | 3945 | 51.46% | 70 | 138 | 50.72% | | This PR | 4055 | 4055 | 100.00% | 138 | 138 | 100.00% | | Delta | +2025 | +90 | +48.54% | +68 | 0 | +49.28% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/159315 Approved by: https://github.com/williamwen42, https://github.com/Skylion007
This commit is contained in:
parent
a5725965ea
commit
40c4d61f9a
|
|
@ -2,12 +2,9 @@ import enum
|
|||
import types
|
||||
from typing import Optional, overload
|
||||
|
||||
from torch._dynamo.types import (
|
||||
DynamoCallback,
|
||||
DynamoGuardCompleteHook,
|
||||
DynamoGuardHook,
|
||||
GuardFn,
|
||||
)
|
||||
from torch._dynamo.guards import GuardManagerWrapper
|
||||
from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook
|
||||
from torch._guards import CompileId
|
||||
|
||||
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
|
||||
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
|
||||
|
|
@ -25,14 +22,20 @@ def raise_sigtrap() -> None: ...
|
|||
|
||||
class _CacheEntry:
|
||||
def check_fn(self, *args: object, **kwargs: object) -> bool: ...
|
||||
def update_diff_guard_root_manager(self) -> None: ...
|
||||
code: types.CodeType
|
||||
compile_id: CompileId
|
||||
# If we run into circular issues, just use object
|
||||
guard_manager: GuardManagerWrapper
|
||||
next: _CacheEntry | None
|
||||
|
||||
class _PrecompileEntry:
|
||||
guard_manager: GuardFn
|
||||
guard_manager: GuardManagerWrapper
|
||||
|
||||
class _ExtraState:
|
||||
def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...
|
||||
def invalidate(
|
||||
self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper
|
||||
) -> None: ...
|
||||
|
||||
class _FrameAction(enum.IntEnum):
|
||||
DEFAULT = 0
|
||||
|
|
@ -69,7 +72,9 @@ py_opcode_caches: list[int]
|
|||
|
||||
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
|
||||
def _load_precompile_entry(
|
||||
code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
|
||||
code: types.CodeType,
|
||||
guard_manager: GuardManagerWrapper,
|
||||
dynamo_code: types.CodeType,
|
||||
) -> None: ...
|
||||
def _reset_precompile_entries(code: types.CodeType) -> None: ...
|
||||
def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ...
|
||||
|
|
|
|||
|
|
@ -7,8 +7,15 @@ class GlobalStateGuard:
|
|||
def check(self) -> bool: ...
|
||||
def reason(self) -> str: ...
|
||||
|
||||
class LeafGuard: ...
|
||||
class GuardDebugInfo: ...
|
||||
class LeafGuard:
|
||||
def verbose_code_parts(self) -> list[str]: ...
|
||||
|
||||
class RelationalGuard: ...
|
||||
|
||||
class GuardDebugInfo:
|
||||
verbose_code_parts: list[str]
|
||||
result: bool
|
||||
num_guards_executed: int
|
||||
|
||||
class GuardManager:
|
||||
def check(self, value) -> bool: ...
|
||||
|
|
@ -36,6 +43,84 @@ class GuardManager:
|
|||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def grad_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def generic_getattr_manager(
|
||||
self,
|
||||
attr: str,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def getitem_manager(
|
||||
self,
|
||||
key,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def get_generic_dict_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def list_getitem_manager(
|
||||
self,
|
||||
key,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def tuple_getitem_manager(
|
||||
self,
|
||||
key,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def set_getitem_manager(
|
||||
self,
|
||||
index,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def func_defaults_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def func_kwdefaults_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def tuple_iterator_getitem_manager(
|
||||
self,
|
||||
index,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def weakref_call_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def call_function_no_args_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def global_weakref_manager(
|
||||
self,
|
||||
global_name: str,
|
||||
|
|
@ -91,7 +176,44 @@ class GuardManager:
|
|||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
|
||||
def get_root(self) -> RootGuardManager: ...
|
||||
def get_source(self) -> str: ...
|
||||
def fail_count(self) -> int: ...
|
||||
def get_child_managers(self) -> list[GuardManager]: ...
|
||||
def repr(self) -> str: ...
|
||||
def type_of_guarded_value(self) -> str: ...
|
||||
def get_leaf_guards(self) -> list[LeafGuard]: ...
|
||||
def get_accessors(self) -> list[GuardManager]: ...
|
||||
def is_guarded_value_immutable(self) -> bool: ...
|
||||
def is_tag_safe(self) -> bool: ...
|
||||
def is_tag_safe_root(self) -> bool: ...
|
||||
def has_no_accessors(self) -> bool: ...
|
||||
def has_object_aliasing_guard(self) -> bool: ...
|
||||
def get_type_of_guarded_value(self) -> type: ...
|
||||
def type_dict_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def type_mro_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def code_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def closure_manager(
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
# Leaf guards
|
||||
def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
|
||||
|
|
@ -106,7 +228,94 @@ class GuardManager:
|
|||
def add_torch_function_mode_stack_guard(
|
||||
self, initial_stack, verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_mapping_keys_guard(self, value, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_dict_length_check_guard(
|
||||
self, value, verbose_code_parts: list[str]
|
||||
) -> None: ...
|
||||
def add_length_check_guard(self, value, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_true_match_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_false_match_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_none_match_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_not_none_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_dispatch_key_set_guard(
|
||||
self,
|
||||
dispatch_key,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_tensor_match_guard(
|
||||
self,
|
||||
value,
|
||||
sizes,
|
||||
strides,
|
||||
tensor_name,
|
||||
verbose_code_parts: list[str],
|
||||
ptype,
|
||||
dispatch_keys,
|
||||
) -> None: ...
|
||||
def add_dynamic_indices_guard(
|
||||
self,
|
||||
value,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_no_hasattr_guard(
|
||||
self,
|
||||
attr_name,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_dict_contains_guard(
|
||||
self,
|
||||
contains,
|
||||
key,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_type_match_guard(
|
||||
self,
|
||||
value,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_dict_version_guard(
|
||||
self,
|
||||
value,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_set_contains_guard(
|
||||
self,
|
||||
contains,
|
||||
item,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_tuple_iterator_length_guard(
|
||||
self,
|
||||
length,
|
||||
type_id,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_range_iterator_match_guard(
|
||||
self,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
type_id,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_default_device_guard(
|
||||
self,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def mark_tag_safe(self) -> None: ...
|
||||
def mark_tag_safe_root(self) -> None: ...
|
||||
|
||||
class RootGuardManager(GuardManager):
|
||||
def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
|
||||
|
|
@ -118,6 +327,7 @@ class RootGuardManager(GuardManager):
|
|||
def clone_manager(
|
||||
self, clone_filter_fn: Callable[[GuardManager], bool]
|
||||
) -> RootGuardManager: ...
|
||||
def attach_compile_id(self, compile_id: str) -> None: ...
|
||||
|
||||
class DictGuardManager(GuardManager):
|
||||
def get_key_manager(
|
||||
|
|
@ -134,6 +344,9 @@ class DictGuardManager(GuardManager):
|
|||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def get_key_value_managers(
|
||||
self,
|
||||
) -> dict[int, tuple[GuardManager, GuardManager]]: ...
|
||||
|
||||
# Guard accessor stubs
|
||||
class GuardAccessor: ...
|
||||
|
|
@ -146,8 +359,8 @@ class GetAttrGuardAccessor(GuardAccessor):
|
|||
def get_attr_name(self) -> str: ...
|
||||
|
||||
def install_object_aliasing_guard(
|
||||
guard_managers: list[GuardManager],
|
||||
tensor_names: list[str],
|
||||
x: GuardManager,
|
||||
y: GuardManager,
|
||||
verbose_code_parts: list[str],
|
||||
): ...
|
||||
def install_no_tensor_aliasing_guard(
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -31,7 +31,7 @@ import re
|
|||
import sys
|
||||
import traceback
|
||||
import weakref
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Sequence
|
||||
from dataclasses import dataclass, field as dc_field
|
||||
from types import CodeType
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
|
||||
|
|
@ -57,6 +57,7 @@ from torch._guards import (
|
|||
)
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch.export.dynamic_shapes import _ConstraintTarget
|
||||
from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
|
|
@ -388,7 +389,7 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
compiler_fn: Optional[CompilerFn],
|
||||
root_tx: "InstructionTranslatorBase",
|
||||
export: bool,
|
||||
export_constraints: Any,
|
||||
export_constraints: Sequence[_ConstraintTarget],
|
||||
frame_state: Any,
|
||||
local_scope: Scope,
|
||||
global_scope: Scope,
|
||||
|
|
@ -414,7 +415,7 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
# de-duplicate graph inputs by source and reuse the tracker
|
||||
self.input_source_to_var: dict[Source, VariableTracker] = {}
|
||||
self.export = export
|
||||
self.export_constraints = export_constraints
|
||||
self.export_constraints = export_constraints # type: ignore[assignment]
|
||||
self.frame_state = frame_state
|
||||
self.cleanup_hooks: list[Callable[[], Any]] = []
|
||||
# compile_id is an id number for the current torch.compile
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ def debug_insert_nops(
|
|||
compiler_fn=None,
|
||||
root_tx=None, # type: ignore[arg-type]
|
||||
export=False,
|
||||
export_constraints=None,
|
||||
export_constraints=[],
|
||||
frame_state={"_id": 0},
|
||||
# TODO: shouldn't this be f_locals/f_globals from frame?
|
||||
local_scope=locals(),
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@ class Guard:
|
|||
guard_types: Optional[list[str]] = None
|
||||
code_list: Optional[list[str]] = None
|
||||
obj_weakref: Optional[object] = None
|
||||
guarded_class_weakref: Optional[type] = None
|
||||
guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None
|
||||
|
||||
stack: Optional[CapturedTraceback] = None
|
||||
user_stack: Optional[traceback.StackSummary] = None
|
||||
|
|
@ -380,7 +380,7 @@ class Guard:
|
|||
def set_export_info(
|
||||
self,
|
||||
guard_type: str,
|
||||
guarded_class: Optional[type],
|
||||
guarded_class: Optional[weakref.ReferenceType[Any]],
|
||||
code_list: list[str],
|
||||
obj_weakref: object,
|
||||
) -> None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user