import enum from collections.abc import Callable from typing import Any, Optional, TypeAlias import torch # TODO: We should move the `GuardManagerType` # defined in `guards.py` here and update other # imports GuardManagerType: TypeAlias = enum.Enum class GlobalStateGuard: def check(self) -> bool: ... def reason(self) -> str: ... class LeafGuard: def verbose_code_parts(self) -> list[str]: ... class RelationalGuard: ... class GuardDebugInfo: verbose_code_parts: list[str] result: bool num_guards_executed: int class GuardManager: def check(self, value: Any) -> bool: ... def check_verbose(self, value: Any) -> GuardDebugInfo: ... # Accessors def globals_dict_manager( self, f_globals: dict[str, Any], source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def framelocals_manager( self, key: tuple[str, int], source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def dict_getitem_manager( self, key: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def grad_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def generic_getattr_manager( self, attr: str, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def getitem_manager( self, key: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def get_generic_dict_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def list_getitem_manager( self, key: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def tuple_getitem_manager( self, key: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... 
def set_getitem_manager( self, index: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def func_defaults_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def func_kwdefaults_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def tuple_iterator_getitem_manager( self, index: Any, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def weakref_call_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def call_function_no_args_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def global_weakref_manager( self, global_name: str, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def type_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def getattr_manager( self, attr: str, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def tensor_property_size_manager( self, idx: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def tensor_property_shape_manager( self, idx: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def tensor_property_storage_offset_manager( self, idx: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def indexed_manager( self, idx: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def lambda_manager( self, python_lambda: Callable[..., Any], source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def get_root(self) -> RootGuardManager: ... 
def get_source(self) -> str: ... def fail_count(self) -> int: ... def get_child_managers(self) -> list[GuardManager]: ... def repr(self) -> str: ... def type_of_guarded_value(self) -> str: ... def get_leaf_guards(self) -> list[LeafGuard]: ... def get_accessors(self) -> list[GuardManager]: ... def is_guarded_value_immutable(self) -> bool: ... def is_tag_safe(self) -> bool: ... def is_tag_safe_root(self) -> bool: ... def has_no_accessors(self) -> bool: ... def has_object_aliasing_guard(self) -> bool: ... def get_type_of_guarded_value(self) -> type: ... def type_dict_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def type_mro_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def code_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def closure_manager( self, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... # Leaf guards def add_lambda_guard( self, user_lambda: Callable[..., Any], verbose_code_parts: list[str] ) -> None: ... def add_id_match_guard( self, id_val: int, verbose_code_parts: list[str] ) -> None: ... def add_equals_match_guard( self, equals_val: Any, verbose_code_parts: list[str], ) -> None: ... def add_global_state_guard( self, initial_state: Any, verbose_code_parts: list[str] ) -> None: ... def add_torch_function_mode_stack_guard( self, initial_stack: list[Any], verbose_code_parts: list[str] ) -> None: ... def add_mapping_keys_guard( self, value: Any, verbose_code_parts: list[str] ) -> None: ... def add_dict_length_check_guard( self, value: int, verbose_code_parts: list[str] ) -> None: ... def add_length_check_guard( self, value: int, verbose_code_parts: list[str] ) -> None: ... def add_true_match_guard( self, verbose_code_parts: list[str], ) -> None: ... def add_false_match_guard( self, verbose_code_parts: list[str], ) -> None: ... 
def add_none_match_guard( self, verbose_code_parts: list[str], ) -> None: ... def add_not_none_guard( self, verbose_code_parts: list[str], ) -> None: ... def add_dispatch_key_set_guard( self, dispatch_key: Any, verbose_code_parts: list[str], ) -> None: ... def add_tensor_match_guard( self, value: Any, sizes: list[int], strides: list[int], tensor_name: str, verbose_code_parts: list[str], ptype: Any, dispatch_keys: Any, ) -> None: ... def add_dynamic_indices_guard( self, value: set[Any], verbose_code_parts: list[str], ) -> None: ... def add_no_hasattr_guard( self, attr_name: str, verbose_code_parts: list[str], ) -> None: ... def add_dict_contains_guard( self, contains: bool, key: Any, verbose_code_parts: list[str], ) -> None: ... def add_type_match_guard( self, value: int, verbose_code_parts: list[str], ) -> None: ... def add_dict_version_guard( self, value: Any, verbose_code_parts: list[str], ) -> None: ... def add_set_contains_guard( self, contains: bool, item: Any, verbose_code_parts: list[str], ) -> None: ... def add_dual_level_match_guard( self, level: int, verbose_code_parts: list[str], ) -> None: ... def add_float_is_nan_guard( self, verbose_code_parts: list[str], ) -> None: ... def add_complex_is_nan_guard( self, verbose_code_parts: list[str], ) -> None: ... def add_tuple_iterator_length_guard( self, length: int, type_id: int, verbose_code_parts: list[str], ) -> None: ... def add_range_iterator_match_guard( self, start: int, stop: int, step: int, type_id: int, verbose_code_parts: list[str], ) -> None: ... def add_default_device_guard( self, verbose_code_parts: list[str], ) -> None: ... def mark_tag_safe(self) -> None: ... def mark_tag_safe_root(self) -> None: ... class RootGuardManager(GuardManager): def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... def add_epilogue_lambda_guard( self, guard: LeafGuard, verbose_code_parts: list[str], ) -> None: ... def clone_manager( self, clone_filter_fn: Callable[[GuardManager], bool] ) -> RootGuardManager: ... 
def attach_compile_id(self, compile_id: str) -> None: ... class DictGuardManager(GuardManager): def get_key_manager( self, index: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def get_value_manager( self, index: int, source: str, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: ... def get_key_value_managers( self, ) -> dict[int, tuple[GuardManager, GuardManager]]: ... # Guard accessor stubs class GuardAccessor: ... class DictGetItemGuardAccessor(GuardAccessor): ... class GetGenericDictGuardAccessor(GuardAccessor): ... class TypeDictGuardAccessor(GuardAccessor): ... class TypeMROGuardAccessor(GuardAccessor): ... class ClosureGuardAccessor(GuardAccessor): ... class TupleGetItemGuardAccessor(GuardAccessor): ... class TypeGuardAccessor(GuardAccessor): ... class CodeGuardAccessor(GuardAccessor): ... class FuncDefaultsGuardAccessor(GuardAccessor): ... class FuncKwDefaultsGuardAccessor(GuardAccessor): ... class GetAttrGuardAccessor(GuardAccessor): def get_attr_name(self) -> str: ... def install_object_aliasing_guard( x: GuardManager, y: GuardManager, verbose_code_parts: list[str], ) -> None: ... def install_no_tensor_aliasing_guard( guard_managers: list[GuardManager], tensor_names: list[str], verbose_code_parts: list[str], ) -> None: ... def install_storage_overlapping_guard( overlapping_guard_managers: list[GuardManager], non_overlapping_guard_managers: list[GuardManager], verbose_code_parts: list[str], ) -> None: ... def install_symbolic_shape_guard( guard_managers: list[GuardManager], nargs_int: int, nargs_float: int, py_addr: int, py_addr_keep_alive: Any, verbose_code_parts: list[str], ) -> None: ... def profile_guard_manager( guard_manager: GuardManager, f_locals: dict[str, Any], n_iters: int, ) -> float: ... 
# Stubs for tensor guards and standalone check helpers. Every body is `...`,
# so only the signatures are declared here.
class TensorGuards:
    # Both hints are keyword-only and default to None. When given, each is a
    # list whose entries are either a torch.SymInt or None.
    def __init__(
        self,
        *,
        dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
        dynamic_dims_strides: list[torch.SymInt | None] | None = None,
    ) -> None: ...
    def check(self, *args: Any) -> bool: ...
    # Declared to return either a bool or a str. `tensor_check_names` is
    # keyword-only.
    def check_verbose(
        self, *args: Any, tensor_check_names: Optional[list[str]] = None
    ) -> bool | str: ...

# ---- Assertion helpers (declared to return None) ----------------------
def assert_size_stride(
    item: torch.Tensor,
    size: torch.types._size,
    stride: torch.types._size,
    op_name: str | None = None,
) -> None: ...
def assert_alignment(
    item: torch.Tensor,
    alignment: int,
    op_name: str | None = None,
) -> None: ...

# ---- Identity / type / version checks ---------------------------------
def check_obj_id(obj: object, expected: int) -> bool: ...
def check_type_id(obj: object, expected: int) -> bool: ...
def dict_version(d: dict[Any, Any]) -> int: ...

# Returns a set of ints for the given tensors; `symbolic` defaults to True.
def compute_overlapping_tensors(
    tensors: list[torch.Tensor], symbolic: bool = True
) -> set[int]: ...