From 4a7bc1d522bed4f3f792b4641372ec3d06b79b5d Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 31 Oct 2025 20:42:23 +0000 Subject: [PATCH] [BE][Typing][Dynamo] Type misc files in `torch/_dynamo/variables/` (#166569) Provides type coverage to ~3000 LOC and 200 methods in `torch/_dynamo/variables/` This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...) ### Coverage report: ``` mypy torch_dynamo/variables --linecount-report /tmp/coverage_log ``` Compare before to after - we go from 3826 to 7221 lines covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569 Approved by: https://github.com/williamwen42, https://github.com/Skylion007 --- torch/_dynamo/codegen.py | 8 +- torch/_dynamo/output_graph.py | 3 +- torch/_dynamo/side_effects.py | 3 +- torch/_dynamo/source.py | 11 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/variables/base.py | 115 ++++---- torch/_dynamo/variables/constant.py | 55 ++-- torch/_dynamo/variables/distributed.py | 99 ++++--- torch/_dynamo/variables/iter.py | 116 ++++---- torch/_dynamo/variables/optimizer.py | 77 +++-- torch/_dynamo/variables/script_object.py | 43 ++- torch/_dynamo/variables/sdpa.py | 26 +- torch/_dynamo/variables/streams.py | 1 + torch/_dynamo/variables/torch.py | 1 + torch/_dynamo/variables/torch_function.py | 330 +++++++++++++--------- 15 files changed, 515 insertions(+), 375 deletions(-) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 3a933f3de34..1861b201052 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -153,7 +153,7 @@ class PyCodegen: self.clear_tos() def __call__( - self, value: Union[VariableTracker, Source], allow_cache: bool = True + self, value: Union[VariableTracker, Source, None], allow_cache: bool = True ) -> None: """ Generate code such that top-of-stack (TOS) is set to value. @@ -188,7 +188,7 @@ class PyCodegen: value to handle aliasing (check side_effects.py and search for allow_cache=False). - b) If value.source is None, this is not allowed. TODO - assert this. + b) If value.source is None, this is not allowed Notable effects: 1. `self.top_of_stack` will be set to `value`, if we don't codegen @@ -197,6 +197,7 @@ class PyCodegen: `top_of_stack` or cached `tempvars`, or (b). `value` has special VT types like `NNModuleVariable`, etc. """ + assert value is not None if isinstance(value, Source): # If the source needs to be overridden, use the new one. source = self.overridden_sources.get(value, value) @@ -289,7 +290,8 @@ class PyCodegen: self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append( self.create_load_global( - value.global_mangled_class_name(self.tx), add=True + value.global_mangled_class_name(self.tx), # type: ignore[arg-type] + add=True, ) ) output.extend(create_call_function(2, False)) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index dbdc1a12096..94ce01c01be 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1303,6 +1303,7 @@ class OutputGraph(OutputGraphCommon): # A small codegen optimization because we might have different # VariableTrackers that share the same source. 
+ assert x.source is not None list_idx = x.source.index # type: ignore[attr-defined] if list_idx not in visited: alias_name = self.new_var( @@ -1321,6 +1322,7 @@ class OutputGraph(OutputGraphCommon): ) # operate on alias, handled by suffix codegen + assert x.source is not None old_source = x.source overridden_sources[old_source] = LocalSource(visited[list_idx]) @@ -1864,7 +1866,6 @@ class OutputGraph(OutputGraphCommon): and isinstance(var.value, _ExportModuleSpecTrackerDict) ): potential_side_effects.append(var) - side_effect_refs = [ _get_source_debug_name(var.source) for var in potential_side_effects ] diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index c8df353406c..bd38e9295a0 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -258,6 +258,7 @@ class SideEffects: "Dynamo needs to fully exhaust the generator, which may cause " "unintended variable modifications." ) + assert item.mutation_type is not None if not is_side_effect_safe(item.mutation_type): # TODO plumb HOP information here unimplemented_v2( @@ -373,7 +374,7 @@ class SideEffects: if self.is_attribute_mutation(item): return item in self.store_attr_mutations - + assert item.mutation_type is not None return item.mutation_type.is_modified # type: ignore[attr-defined] def _track_obj( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 9fb4f32d68a..92c6875a88d 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -111,11 +111,14 @@ def is_constant_source(source: Source) -> bool: return False -def _get_source_debug_name(source: Source) -> str: - try: - return source.name() - except NotImplementedError: +def _get_source_debug_name(source: Optional[Source]) -> str: + if source is None: return "" + else: + try: + return source.name() + except NotImplementedError: + return "" @dataclasses.dataclass(frozen=True) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 0cc6a78dff9..4a62417540b 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -5201,7 +5201,7 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): ): if isinstance(val, ConstantVariable) and val.value is None: try: - val = tos.next_variable(self) + val = tos.next_variable(self) # type: ignore[arg-type] except (StopIteration, exc.ObservedUserStopIteration) as ex: # To implement SEND, we have to look at the implementation # when the iterator returns StopIteration. This translates to this code diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 731c29a365a..0abf2cc91e7 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Core variable tracking functionality for Dynamo. This module defines the fundamental classes and systems used to track and manage variables during Dynamo's operation. @@ -18,7 +16,10 @@ computations. import collections from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView from enum import Enum -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, NoReturn, Optional, TYPE_CHECKING + +from torch._guards import Guard +from torch.fx.proxy import Node from .. 
import graph_break_hints, variables from ..current_scope_id import current_scope_id @@ -30,7 +31,7 @@ from ..utils import cmp_name_to_op_mapping, istype if TYPE_CHECKING: from ..codegen import PyCodegen - from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase + from ..symbolic_convert import InstructionTranslator class SourceType(Enum): @@ -115,10 +116,10 @@ class ValueMutationNew(MutationType): def __init__(self) -> None: super().__init__(SourceType.New) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self is other @@ -139,7 +140,7 @@ class ValueMutationExisting(MutationType): # filter out which pre-existing values it needs to generate mutation for. is_modified: bool - def __init__(self, is_modified: bool = False): + def __init__(self, is_modified: bool = False) -> None: super().__init__(SourceType.Existing) self.is_modified = is_modified @@ -150,7 +151,7 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. """ - def __init__(self, typ: SourceType): + def __init__(self, typ: SourceType) -> None: super().__init__(typ) @@ -166,7 +167,7 @@ class AttributeMutationExisting(AttributeMutation): be used afterwards in Python. """ - def __init__(self): + def __init__(self) -> None: super().__init__(SourceType.Existing) @@ -182,16 +183,16 @@ class AttributeMutationNew(AttributeMutation): the Python world. """ - def __init__(self, cls_source: Optional[Source] = None): + def __init__(self, cls_source: Optional[Source] = None) -> None: super().__init__(SourceType.New) self.cls_source = cls_source -def _is_top_level_scope(scope_id): +def _is_top_level_scope(scope_id: int) -> bool: return scope_id == 1 -def is_side_effect_safe(m: MutationType): +def is_side_effect_safe(m: MutationType) -> bool: scope_id = current_scope_id() # In the top-level scope (if no HigherOrderOperators are involved), @@ -209,15 +210,15 @@ def is_side_effect_safe(m: MutationType): class AsPythonConstantNotImplementedError(NotImplementedError): vt: "VariableTracker" - def __init__(self, vt: "VariableTracker"): + def __init__(self, vt: "VariableTracker") -> None: super().__init__(f"{vt} is not a constant") self.vt = vt class VariableTrackerMeta(type): - all_subclasses = [] + all_subclasses: list[type] = [] - def __instancecheck__(cls, instance) -> bool: + def __instancecheck__(cls: type, instance: object) -> bool: """Make isinstance work with LazyVariableTracker""" # This is super expensive - just having it costs over 4% of tracing # time! 
@@ -227,8 +228,10 @@ class VariableTrackerMeta(type): instance = instance.realize() return type.__instancecheck__(cls, instance) - def __init__(cls, name, bases, attrs) -> None: - super().__init__(name, bases, attrs) + def __init__( + cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: + super().__init__(name, bases, attrs) # type: ignore[misc] VariableTrackerMeta.all_subclasses.append(cls) @@ -252,7 +255,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): "user_code_variable_name", } - def clone(self, **kwargs): + def clone(self, **kwargs: Any) -> "VariableTracker": """Shallow copy with some (optional) changes""" args = dict(self.__dict__) args.update(kwargs) @@ -295,14 +298,14 @@ class VariableTracker(metaclass=VariableTrackerMeta): def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def debug_repr(self): + def debug_repr(self) -> str: # Intended to be overridden to provide more info try: return repr(self.as_python_constant()) except NotImplementedError: return repr(self) - def python_type(self): + def python_type(self) -> type: """ Abstract method to be implemented by subclasses of VariableTracker. @@ -331,17 +334,17 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: raise NotImplementedError(f"{self} has no type") from None - def python_type_name(self): + def python_type_name(self) -> str: try: return self.python_type().__name__ except NotImplementedError: return "" - def as_python_constant(self): + def as_python_constant(self) -> Any: """For constants""" raise AsPythonConstantNotImplementedError(self) - def guard_as_python_constant(self): + def guard_as_python_constant(self) -> Any: """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" try: return self.as_python_constant() @@ -353,18 +356,20 @@ class VariableTracker(metaclass=VariableTrackerMeta): hints=[], ) - def is_python_constant(self): + def is_python_constant(self) -> bool: try: self.as_python_constant() return True except NotImplementedError: return False - def make_guard(self, fn): + def make_guard(self, fn: Callable[..., Any]) -> Guard: if self.source: return self.source.make_guard(fn) raise NotImplementedError + # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase` + # and cascade that (large blast radius) def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: """getattr(self, name) returning a python constant""" raise NotImplementedError @@ -381,17 +386,17 @@ class VariableTracker(metaclass=VariableTrackerMeta): install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) return variables.ConstantVariable.create(value, source=source) - def is_proxy(self): + def is_proxy(self) -> bool: try: self.as_proxy() return True except NotImplementedError: return False - def as_proxy(self): + def as_proxy(self) -> Any: raise NotImplementedError(str(self)) - def maybe_fx_node(self): + def maybe_fx_node(self) -> Optional[Node]: try: proxy = self.as_proxy() import torch.fx @@ -402,13 +407,13 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: return None - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: raise NotImplementedError - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: raise NotImplementedError - def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def 
force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: # like unpack_var_sequence, but should only be used when it is # safe to eagerly (vs. lazily) unpack this variable. # e.g. map(f, x) is normally evaluated lazily but sometimes @@ -417,7 +422,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): # it should only be called once. return self.unpack_var_sequence(tx) - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: Any) -> bool: try: self.unpack_var_sequence(tx) return True @@ -425,13 +430,15 @@ class VariableTracker(metaclass=VariableTrackerMeta): return False # NB: don't call force_unpack_var_sequence, especially if it mutates! - def has_force_unpack_var_sequence(self, tx) -> bool: + def has_force_unpack_var_sequence(self, tx: Any) -> bool: return self.has_unpack_var_sequence(tx) # Forces unpacking the var sequence while also applying a function to each element. # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: Any, fn: Callable[["VariableTracker"], Any] + ) -> None: assert self.has_force_unpack_var_sequence(tx) for v in self.unpack_var_sequence(tx): fn(v) @@ -444,9 +451,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): hints=[], ) - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", @@ -459,9 +464,9 @@ class VariableTracker(metaclass=VariableTrackerMeta): def call_function( self, - tx: "InstructionTranslator", + tx: Any, args: Sequence["VariableTracker"], - kwargs: "dict[str, VariableTracker]", + kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported function call", @@ -475,10 +480,10 @@ class VariableTracker(metaclass=VariableTrackerMeta): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", + tx: Any, + name: str, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": if name == "__len__" and self.has_unpack_var_sequence(tx): assert not (args or kwargs) @@ -562,7 +567,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): hints=hints, ) - def set_name_hint(self, name): + def set_name_hint(self, name: str) -> None: pass def realize(self) -> "VariableTracker": @@ -573,11 +578,11 @@ class VariableTracker(metaclass=VariableTrackerMeta): """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" return self - def is_realized(self): + def is_realized(self) -> bool: """Used by LazyVariableTracker to indicate an unrealized node""" return True - def next_variable(self, tx): + def next_variable(self, tx: Any) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported next() call", context=f"next({self})", @@ -585,20 +590,20 @@ class VariableTracker(metaclass=VariableTrackerMeta): hints=[*graph_break_hints.USER_ERROR], ) - def is_strict_mode(self, tx): - return tx.strict_checks_fn and tx.strict_checks_fn(self) + def is_strict_mode(self, tx: Any) -> bool: + return bool(tx.strict_checks_fn and tx.strict_checks_fn(self)) - def is_mutable(self): + def is_mutable(self) -> bool: """Whether Dynamo allows mutation on this 
variable.""" return not self.is_immutable() - def is_immutable(self): + def is_immutable(self) -> bool: """Whether Dynamo bans mutation on this variable.""" return self.mutation_type is None @staticmethod def build( - tx: "InstructionTranslatorBase", + tx: Any, value: Any, source: Optional[Source] = None, ) -> Any: @@ -611,8 +616,8 @@ class VariableTracker(metaclass=VariableTrackerMeta): def __init__( self, *, - source: Source = None, - mutation_type: MutationType = None, + source: Optional[Source] = None, + mutation_type: Optional[MutationType] = None, ) -> None: super().__init__() self.source = source @@ -636,12 +641,12 @@ class VariableTracker(metaclass=VariableTrackerMeta): assert source is not None -def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None: +def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn: msg = variables.ConstantVariable.create(msg_str) raise_observed_exception(TypeError, tx, args=[msg]) -def typestr(*objs): +def typestr(*objs: object) -> str: if len(objs) == 1: (obj,) = objs if isinstance(obj, VariableTracker): diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index b7d150bfab3..afe445514eb 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Constant and enum variable tracking in Dynamo. @@ -8,8 +6,9 @@ values during compilation, ensuring proper handling of Python literals and maintaining type safety through the compilation process. """ +import enum import operator -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING, Union import torch from torch._dynamo.source import AttrSource, GetItemSource @@ -40,7 +39,7 @@ class ConstantVariable(VariableTracker): """ @staticmethod - def create(value, **kwargs) -> VariableTracker: + def create(value: Any, **kwargs: Any) -> VariableTracker: """ Create a `ConstantVariable` based on the given value, and supports automatic routing for collection types like `tuple` (in which case we'd @@ -76,7 +75,7 @@ class ConstantVariable(VariableTracker): return ConstantVariable(value, **kwargs) - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) assert ConstantVariable.is_base_literal(value), f""" Cannot construct `ConstantVariable` for value of type {type(value)}. @@ -92,48 +91,52 @@ its type to `common_constant_types`. else: self.value = value - def as_proxy(self): + def as_proxy(self) -> Any: return self.value def __repr__(self) -> str: return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value - def is_python_constant(self): + def is_python_constant(self) -> Literal[True]: return True @property - def items(self): + def items(self) -> list[VariableTracker]: """ Need this when adding a BaseListVariable and a ConstantVariable together. Happens in detectron2. 
""" return self.unpack_var_sequence(tx=None) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: return ConstantVariable.create( self.value[arg.as_python_constant()], ) @staticmethod - def is_base_literal(obj): + def is_base_literal(obj: object) -> bool: return type(obj) in common_constant_types @staticmethod - def is_literal(obj): + def is_literal(obj: object) -> bool: if type(obj) in (list, tuple, set, frozenset, torch.Size): - return all(ConstantVariable.is_literal(x) for x in obj) + return all(ConstantVariable.is_literal(x) for x in obj) # type: ignore[attr-defined] return ConstantVariable.is_base_literal(obj) - def unpack_var_sequence(self, tx): + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] + ) -> list[VariableTracker]: try: return [ConstantVariable.create(x) for x in self.as_python_constant()] except TypeError as e: raise NotImplementedError from e - def const_getattr(self, tx: "InstructionTranslator", name): + def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if not hasattr(self.value, name): raise_observed_exception(AttributeError, tx, args=[name]) member = getattr(self.value, name) @@ -144,10 +147,10 @@ its type to `common_constant_types`. def call_method( self, tx: "InstructionTranslator", - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from .tensor import SymNodeVariable if name == "format" and istype(self.value, str): @@ -262,7 +265,7 @@ its type to `common_constant_types`. def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: result = hasattr(self.value, name) return variables.ConstantVariable.create(result) @@ -274,12 +277,14 @@ class EnumVariable(VariableTracker): both standard Enum and IntEnum with proper value tracking and comparison. 
""" - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value @classmethod - def create(cls, cls_type, value_vt, options): + def create( + cls, cls_type: Any, value_vt: VariableTracker, options: Any + ) -> "EnumVariable": if isinstance(value_vt, variables.ConstantVariable): for member in list(cls_type): if member.value == value_vt.as_python_constant(): @@ -293,7 +298,7 @@ class EnumVariable(VariableTracker): hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], ) - def as_proxy(self): + def as_proxy(self) -> Union[enum.Enum, int]: if isinstance(self.value, int): return int(self.value) # convert IntEnum to a normal int return self.value @@ -301,10 +306,10 @@ class EnumVariable(VariableTracker): def __repr__(self) -> str: return f"EnumVariable({type(self.value)})" - def as_python_constant(self): + def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]: return self.value - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if not hasattr(self.value, name): raise NotImplementedError if name in cmp_name_to_op_mapping: diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 37878abbb37..eb39dd8fa3e 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Distributed computing variable tracking classes for PyTorch Dynamo. @@ -22,7 +20,7 @@ checks and proper tracking of distributed state and operations across processes. import functools import inspect -from typing import TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState @@ -40,6 +38,7 @@ from .constant import ConstantVariable, EnumVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -54,7 +53,7 @@ class DistributedVariable(VariableTracker): and hold the tracking value for the corresponding distributed object. 
""" - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) if not DistributedVariable.is_available(): unimplemented_v2( @@ -67,16 +66,16 @@ class DistributedVariable(VariableTracker): ) self.value = value - def python_type(self): + def python_type(self) -> type: return type(self.value) @staticmethod - def is_available(): + def is_available() -> bool: # check if the distributed package is available or not return torch.distributed.is_available() -def is_from_local(value): +def is_from_local(value: object) -> bool: if not DistributedVariable.is_available(): return False from torch.distributed.tensor import DTensor @@ -84,7 +83,7 @@ def is_from_local(value): return inspect.isfunction(value) and value is DTensor.from_local -def is_constant_pg_functions(value): +def is_constant_pg_functions(value: object) -> bool: if not DistributedVariable.is_available(): return False @@ -114,7 +113,7 @@ class WorldMetaClassVariable(DistributedVariable): """ @classmethod - def is_group_member_type(cls, value): + def is_group_member_type(cls, value: object) -> bool: if not cls.is_available(): return False @@ -124,10 +123,12 @@ class WorldMetaClassVariable(DistributedVariable): def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "WORLD": + assert self.source source = AttrSource(base=self.source, member="WORLD") install_guard(source.make_guard(GuardBuilder.ID_MATCH)) return ProcessGroupVariable(self.value.WORLD) elif name == "NON_GROUP_MEMBER": + assert self.source source = AttrSource(base=self.source, member="NON_GROUP_MEMBER") install_guard(source.make_guard(GuardBuilder.ID_MATCH)) return EnumVariable(self.value.NON_GROUP_MEMBER) @@ -136,7 +137,7 @@ class WorldMetaClassVariable(DistributedVariable): class PlacementClassVariable(DistributedVariable): @staticmethod - def is_placement_type(value): + def is_placement_type(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -145,15 +146,15 @@ class PlacementClassVariable(DistributedVariable): return isinstance(value, type) and issubclass(value, Placement) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.source: # NOTE: we don't need to track mutations to the placement class as they # are supposed to be immutable. @@ -168,16 +169,15 @@ class PlacementClassVariable(DistributedVariable): class PlacementVariable(DistributedVariable): @staticmethod - def is_placement(value): + def is_placement(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. 
if not DistributedVariable.is_available(): return False - from torch.distributed.tensor.placement_types import Placement return isinstance(value, Placement) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -187,11 +187,11 @@ class PlacementVariable(DistributedVariable): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from . import ConstantVariable # Placement types dynamo tracking only allows following methods @@ -228,15 +228,16 @@ class PlacementVariable(DistributedVariable): args = [x.as_python_constant() for x in args] kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + assert method is not None if name == "__setattr__": method(self.value, *args, **kwargs) return self constant_val = method(self.value, *args, **kwargs) return ConstantVariable.create(constant_val) - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) # type: ignore[arg-type] - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: # Reconstruct the Placement object by calling its constructor # e.g., Shard(0), Replicate(), Partial() from torch.distributed.tensor.placement_types import Partial, Replicate, Shard @@ -263,7 +264,7 @@ class PlacementVariable(DistributedVariable): class DeviceMeshVariable(DistributedVariable): @staticmethod - def is_device_mesh(value): + def is_device_mesh(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -272,7 +273,7 @@ class DeviceMeshVariable(DistributedVariable): return istype(value, DeviceMesh) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -289,11 +290,11 @@ class DeviceMeshVariable(DistributedVariable): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "size": const_args = [x.as_python_constant() for x in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} @@ -338,16 +339,16 @@ class ProcessGroupVariable(DistributedVariable): or just graph-break whenever one of our special cases is not hit? 
""" - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "rank": return variables.ConstantVariable.create(self.value.rank()) if name == "size": @@ -357,7 +358,7 @@ class ProcessGroupVariable(DistributedVariable): return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "group_name": return variables.ConstantVariable.create(self.value.group_name) if name in ["rank", "size"]: @@ -368,7 +369,7 @@ class ProcessGroupVariable(DistributedVariable): return super().var_getattr(tx, name) @staticmethod - def is_process_group(value): + def is_process_group(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -386,11 +387,11 @@ class BackwardHookVariable(VariableTracker): @staticmethod def create( - tx, + tx: "InstructionTranslator", module: VariableTracker, user_hooks: VariableTracker, user_pre_hooks: VariableTracker, - ): + ) -> "BackwardHookVariable": if not compiled_autograd.compiled_autograd_enabled: unimplemented_v2( gb_type="Module-level backwards hooks require compiled autograd.", @@ -401,7 +402,9 @@ class BackwardHookVariable(VariableTracker): ], ) - def _in_graph_bw_hooks(bw_state: BackwardState): + def _in_graph_bw_hooks( + bw_state: BackwardState, + ) -> torch.utils.hooks.BackwardHook: """ Rather than installing the user hooks in the graph (which don't survive AotAutograd), we install hooks that will call @@ -448,7 +451,7 @@ class BackwardHookVariable(VariableTracker): module: VariableTracker, user_hooks: VariableTracker, user_pre_hooks: VariableTracker, - **options, + **options: Any, ) -> None: super().__init__(**options) self.proxy = proxy @@ -456,13 +459,13 @@ class BackwardHookVariable(VariableTracker): self.user_hooks = user_hooks self.user_pre_hooks = user_pre_hooks - def as_proxy(self): + def as_proxy(self) -> torch.fx.Proxy: return self.proxy def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: @@ -470,7 +473,9 @@ class BackwardHookVariable(VariableTracker): return self._setup_hook(tx, name, *args, **kwargs) return super().call_method(tx, name, args, kwargs) - def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args): + def _setup_hook( + self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker + ) -> VariableTracker: from .builder import wrap_fx_proxy return wrap_fx_proxy( diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index f790b65830b..b8b7ca8b9fd 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This module provides iterator-related variable tracking functionality for Dynamo. It implements variable classes for handling Python iterators and itertools functions @@ -16,7 +14,8 @@ handling of iterator operations during code transformation and optimization. 
""" import itertools -from typing import TYPE_CHECKING, Union +from collections.abc import Callable +from typing import Any, Sequence, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( @@ -45,20 +44,20 @@ MAX_ITERATOR_LIMIT = 100 * 1024 # 100k class ItertoolsVariable(VariableTracker): - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value def __repr__(self) -> str: return f"ItertoolsVariable({self.value})" - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", + args: Sequence["VariableTracker"], kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": # See also: module `torch._dynamo.polyfills.itertools` @@ -111,7 +110,7 @@ class ItertoolsVariable(VariableTracker): hints=[*graph_break_hints.USER_ERROR], ) - def retrieve_const_key(key): + def retrieve_const_key(key: VariableTracker) -> Any: if isinstance(key, variables.SymNodeVariable): return key.evaluate_expr() elif isinstance(key, variables.ConstantVariable): @@ -144,14 +143,14 @@ class ItertoolsVariable(VariableTracker): if "key" in kwargs: - def keyfunc(x): + def keyfunc(x: VariableTracker) -> Any: return retrieve_const_key( - kwargs.get("key").call_function(tx, [x], {}) + kwargs.get("key").call_function(tx, [x], {}) # type: ignore[union-attr] ) else: - def keyfunc(x): + def keyfunc(x: VariableTracker) -> Any: return retrieve_const_key(x) result = [] @@ -219,10 +218,10 @@ class ItertoolsVariable(VariableTracker): class IteratorVariable(VariableTracker): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: unimplemented_v2( gb_type="Unimplemented next() call", context=f"next({self})", @@ -234,12 +233,16 @@ class IteratorVariable(VariableTracker): # Normally, iterators are accessed lazily. # Example of safe eager unpacking: list(map(f, seq)) # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) - def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: - result = [] + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] self.force_apply_to_var_sequence(tx, result.append) return result - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[Any], Any] + ) -> None: while True: try: fn(self.next_variable(tx)) @@ -249,7 +252,7 @@ class IteratorVariable(VariableTracker): # don't call force_unpack_var_sequence since it can mutate # IteratorVariable state! 
- def has_force_unpack_var_sequence(self, tx) -> bool: + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return True def call_obj_hasattr( @@ -257,12 +260,12 @@ class IteratorVariable(VariableTracker): ) -> "VariableTracker": if name == "__iter__" or name == "__next__": return variables.ConstantVariable.create(True) - super().call_obj_hasattr(tx, name) + return super().call_obj_hasattr(tx, name) def call_method( self, tx: "InstructionTranslator", - name, + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: @@ -287,12 +290,12 @@ class ObjectIteratorVariable(IteratorVariable): > list(b) # empty list """ - def __init__(self, obj: VariableTracker, **kwargs): + def __init__(self, obj: VariableTracker, **kwargs: Any) -> None: super().__init__(**kwargs) self.obj = obj self.generator_exhausted = False - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: if self.generator_exhausted: raise_observed_exception(StopIteration, tx) @@ -306,15 +309,15 @@ class ObjectIteratorVariable(IteratorVariable): class RepeatIteratorVariable(IteratorVariable): - def __init__(self, item: VariableTracker, **kwargs) -> None: + def __init__(self, item: VariableTracker, **kwargs: Any) -> None: super().__init__(**kwargs) self.item = item # Repeat needs no mutation, clone self - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: return self.item - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.extend_output( [ @@ -328,7 +331,12 @@ class RepeatIteratorVariable(IteratorVariable): class CountIteratorVariable(IteratorVariable): - def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: + def __init__( + self, + item: Union[int, VariableTracker] = 0, + step: Union[int, VariableTracker] = 1, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if not isinstance(item, VariableTracker): item = ConstantVariable.create(item) @@ -337,14 +345,14 @@ class CountIteratorVariable(IteratorVariable): self.item = item self.step = step - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() old_item = self.item tx.output.side_effects.mutation(self) self.item = self.item.call_method(tx, "__add__", [self.step], {}) return old_item - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.extend_output( [ @@ -373,7 +381,7 @@ class ZipVariable(IteratorVariable): self, iterables: list[VariableTracker], strict: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance(iterables, list) @@ -382,16 +390,18 @@ class ZipVariable(IteratorVariable): self.index = 0 self.strict = strict - def python_type(self): + def python_type(self) -> type[zip]: # type: ignore[type-arg] return zip - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return all( isinstance(it, list) or it.has_unpack_var_sequence(tx) for it in self.iterables ) - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: assert self.has_unpack_var_sequence(tx) iterables = [] for it in self.iterables: @@ -403,7 +413,7 @@ class 
ZipVariable(IteratorVariable): zipped = zip(*iterables, **kwargs) return [variables.TupleVariable(list(var)) for var in zipped] - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() if len(self.iterables) == 0: @@ -412,7 +422,9 @@ class ZipVariable(IteratorVariable): old_index = self.index args = [] - def get_item(it): + def get_item( + it: Union[list[VariableTracker], VariableTracker], + ) -> VariableTracker: if isinstance(it, list): if old_index >= len(it): raise_observed_exception(StopIteration, tx) @@ -441,7 +453,7 @@ class ZipVariable(IteratorVariable): raise handle_observed_exception(tx) raise UserError( - ValueError, + ValueError, # type: ignore[arg-type] "zip() has one argument of len differing from others", ) from None raise @@ -450,7 +462,7 @@ class ZipVariable(IteratorVariable): self.index += 1 return variables.TupleVariable(args) - def reconstruct_items(self, codegen: "PyCodegen"): + def reconstruct_items(self, codegen: "PyCodegen") -> None: for it in self.iterables: if isinstance(it, list): remaining_items = it[self.index :] @@ -459,7 +471,7 @@ class ZipVariable(IteratorVariable): else: codegen(it) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True ) @@ -483,23 +495,23 @@ class MapVariable(ZipVariable): def __init__( self, fn: VariableTracker, - iterables: list[Union[list[VariableTracker], VariableTracker]], - **kwargs, + iterables: list[VariableTracker], + **kwargs: Any, ) -> None: super().__init__(iterables, **kwargs) self.fn = fn - def python_type(self): + def python_type(self) -> type: return map - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return False - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: args = super().next_variable(tx) - return self.fn.call_function(tx, args.items, {}) + return self.fn.call_function(tx, args.items, {}) # type: ignore[attr-defined] - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True ) @@ -526,23 +538,25 @@ class FilterVariable(IteratorVariable): def __init__( self, fn: VariableTracker, - iterable: Union[list[VariableTracker], VariableTracker], - **kwargs, + iterable: list[VariableTracker], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.fn = fn self.iterable = iterable self.index = 0 - def python_type(self): + def python_type(self) -> type: return filter - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence( tx ) - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: assert self.has_unpack_var_sequence(tx) it = None if isinstance(self.iterable, list): @@ -552,8 +566,8 @@ class FilterVariable(IteratorVariable): filtered = self.fn.call_function(tx, it, {}) return [variables.TupleVariable([filtered])] - def next_variable(self, tx): - def _next(): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + def _next() -> VariableTracker: old_index = self.index if 
isinstance(self.iterable, list): if old_index >= len(self.iterable): @@ -576,7 +590,7 @@ class FilterVariable(IteratorVariable): if pred_res.as_python_constant(): return item - def reconstruct_items(self, codegen: "PyCodegen"): + def reconstruct_items(self, codegen: "PyCodegen") -> None: if isinstance(self.iterable, list): remaining_items = self.iterable[self.index :] codegen.foreach(remaining_items) @@ -584,7 +598,7 @@ class FilterVariable(IteratorVariable): else: codegen(self.iterable) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) codegen(self.fn) self.reconstruct_items(codegen) diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 18f75833551..b64c099ee7e 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This module implements variable tracking for PyTorch optimizers during Dynamo tracing. @@ -24,9 +22,11 @@ optimizer-specific optimizations and safety guarantees. import logging import weakref -from typing import TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING import torch +from torch._dynamo.variables.tensor import TensorVariable +from torch._guards import Source from torch._logging import getArtifactLogger from torch.utils._pytree import tree_map_only @@ -63,13 +63,14 @@ class GuardInstallException(Exception): perf_hint_log = getArtifactLogger(__name__, "perf_hints") -def _is_static_for_cudagraphs(x): +def _is_static_for_cudagraphs(x: torch.Tensor) -> bool: from torch._inductor.cudagraph_trees import get_manager if x.is_cuda: manager = get_manager(x.device.index, False) is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None if manager: + assert manager.current_node is not None return ( is_static_address or manager.current_node._is_cuda_graph_recorded_tensor(x) @@ -91,26 +92,30 @@ class OptimizerVariable(UserDefinedObjectVariable): def __init__( self, - value, - grad_to_source=None, - static_tensor_names=None, - tensor_to_source=None, - **kwargs, + value: torch.optim.Optimizer, + grad_to_source: Optional[dict[Any, GradSource]] = None, + static_tensor_names: Optional[set[str]] = None, + tensor_to_source: Optional[dict[torch.Tensor, Source]] = None, + **kwargs: Any, ) -> None: super().__init__(value, **kwargs) + self.value: torch.optim.Optimizer = value self.grad_to_source = grad_to_source or {} self.tensor_to_source = tensor_to_source or {} self.static_tensor_names = static_tensor_names or set() def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], ) -> "VariableTracker": """This is an optimization to avoid tracing the very slow initialization of the optimizer""" if name == "_init_group": + if not hasattr(self.value, "_init_group"): + # Fallback: if the optimizer does not have _init_group, trace normally + return super().call_method(tx, name, args, kwargs) try: self.graph_break_if_pending_mutation(tx) self.move_step_if_cpu() @@ -135,11 +140,12 @@ class OptimizerVariable(UserDefinedObjectVariable): return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: # Note: this allows us to 
intercept the call in call_method # in the typical case, we return a UserMethodVariable # which will directly inline if name in ("_init_group", "step"): + assert self.source return GetAttrVariable(self, name, source=AttrSource(self.source, name)) if name == "param_groups": @@ -153,7 +159,7 @@ class OptimizerVariable(UserDefinedObjectVariable): return super().var_getattr(tx, name) - def graph_break_if_pending_mutation(self, tx): + def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None: # If there are pending mutations on a parameter (due to using closure) # then we need to graph break to allow the python version of the parameter # to update, so that running _init_group will initialize the states with @@ -167,12 +173,12 @@ class OptimizerVariable(UserDefinedObjectVariable): raise Unsupported("Pending mutation on parameter") - def _set_capturable(self, tx): + def _set_capturable(self, tx: "InstructionTranslator") -> None: from . import LazyVariableTracker # We only set capturable if params are on cuda # and the state is not initialized - def safe_to_set_capturable(group): + def safe_to_set_capturable(group: dict[str, Any]) -> bool: all_uninitialized = True all_gpu = True @@ -199,10 +205,12 @@ class OptimizerVariable(UserDefinedObjectVariable): ) param_group_vt.items[key] = ConstantVariable.create(True) - def get_python_args(self, *args, **kwargs): + def get_python_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], dict[str, Any]]: """Get python values equivalent to the variable tracker args""" - def map_arg(arg): + def map_arg(arg: Any) -> Any: if isinstance(arg, ConstantVariable): return arg.as_python_constant() elif isinstance(arg, ListVariable) and not arg.items: @@ -227,19 +235,19 @@ class OptimizerVariable(UserDefinedObjectVariable): # if this is the case, move it to the GPU # corresponding to the parameter # in most cases this is a no-op because the state is empty - def move_step_if_cpu(self): + def move_step_if_cpu(self) -> None: for p, state in self.value.state.items(): if "step" in state and state["step"].is_cpu: state["step"] = state["step"].to(p.device) - def map_sources_and_install_guards(self, tx): + def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None: from ..decorators import mark_static_address from .lazy import LazyVariableTracker self.grad_to_source = {} self.tensor_to_source = {} - def mark_static(x): + def mark_static(x: Any) -> None: mark_static_address(x, guard=True) tree_map_only(torch.Tensor, mark_static, self.value.state) @@ -252,12 +260,12 @@ class OptimizerVariable(UserDefinedObjectVariable): ) state_source = self.source and AttrSource(self.source, "state") - state_vt = VariableTracker.build(tx, self.value.state, state_source) # We need to realize the top level state dict to populate # the guard locals state_vt.realize() + assert state_source is not None tx.output.guard_on_key_order.add(state_source) # Populate self.grad_to_source and self.tensor_to_source so that we can @@ -308,14 +316,14 @@ class OptimizerVariable(UserDefinedObjectVariable): # Note: to avoid spam logs only warn if perf hint artifact is enabled # (NB: artifacts are only enabled at the debug or warning level) if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): - non_static_grads = [src.name() for src in non_static_grads] + non_static_grad_names = [src.name() for src in non_static_grads] perf_hint_log.warning( ( "Grad tensors %s will be copied during cudagraphs execution." 
"If using cudagraphs and the grad tensor addresses will be the same across runs," " use torch._dynamo.decorators.mark_static_address to elide this copy.", ), - non_static_grads, + non_static_grad_names, ) # We have to again iterate over the state dict to collect the @@ -335,7 +343,9 @@ class OptimizerVariable(UserDefinedObjectVariable): p_state_source, ConstDictKeySource(p_state_source, inner_idx) ) - def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): + def wrap_tensor( + self, tx: "InstructionTranslator", tensor_value: torch.Tensor + ) -> TensorVariable: """Wrap state tensor in a TensorVariable""" from ..decorators import mark_static_address @@ -362,8 +372,13 @@ class OptimizerVariable(UserDefinedObjectVariable): return VariableTracker.build(tx, tensor_value, source) def update_list_args( - self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs - ): + self, + tx: "InstructionTranslator", + args: Iterable[VariableTracker], + kwargs: Any, + py_args: Iterable[Any], + py_kwargs: Any, + ) -> None: """Update the args and kwargs to the traced optimizer call""" for arg, py_arg in zip(args, py_args): if isinstance(arg, ListVariable): @@ -378,13 +393,13 @@ class OptimizerVariable(UserDefinedObjectVariable): source = arg.source and GetItemSource(arg.source, i) arg.items.append(VariableTracker.build(tx, val, source)) - def create_finalizer(self, tx): + def create_finalizer(self, tx: "InstructionTranslator") -> None: names_to_delete = self.static_tensor_names value = self.value tc = tx.output.tracing_context - def init_finalizer(gm): - def clear_static_tensor_refs(): + def init_finalizer(gm: torch.fx.GraphModule) -> None: + def clear_static_tensor_refs() -> None: for name in names_to_delete: gm._buffers.pop(name, None) gm._parameters.pop(name, None) diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index a120ab488ed..85977104977 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs - """ This module implements variable tracking for TorchScript objects during Dynamo tracing. @@ -22,8 +19,13 @@ by limiting operations to known-safe patterns and failing fast for unsafe usage. """ import functools +from collections.abc import Callable +from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec import torch +from torch._guards import Source +from torch.fx.proxy import Proxy from .. 
import graph_break_hints from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported @@ -31,10 +33,19 @@ from .base import VariableTracker from .user_defined import UserDefinedObjectVariable -def _raise_hard_error_if_graph_break(reason): - def deco(fn): +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _raise_hard_error_if_graph_break( + reason: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(fn) - def graph_break_as_hard_error(*args, **kwargs): + def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T: try: return fn(*args, **kwargs) except Unsupported as e: @@ -49,26 +60,26 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable): _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} @classmethod - def is_matching_cls(cls, user_cls: type): + def is_matching_cls(cls, user_cls: type) -> bool: return issubclass(user_cls, torch.ScriptObject) @staticmethod - def create(proxy, value, **options): + def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable": return TorchScriptObjectVariable(proxy, value, **options) - def __init__(self, proxy, value, source, **kwargs) -> None: + def __init__(self, proxy: Proxy, value: Any, source: Source, **kwargs: Any) -> None: super().__init__(value, **kwargs) self.proxy = proxy self.proxy.node.meta["example_value"] = value self.source = source - def as_proxy(self): + def as_proxy(self) -> Proxy: return self.proxy @_raise_hard_error_if_graph_break( "Dynamo cannot safely trace script object due to graph break." ) - def var_getattr(self, tx, name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: from torch._higher_order_ops.torchbind import call_torchbind from ..source import AttrSource @@ -95,7 +106,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable): "Use method calls instead of attribute access.", ], ) - + assert self.source is not None return TorchHigherOrderOperatorVariable.make( call_torchbind, source=AttrSource(self.source, name), @@ -110,7 +121,13 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable): @_raise_hard_error_if_graph_break( "Dynamo cannot safely trace script object due to graph break." 
) - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: unimplemented_v2( gb_type="Weird method call on TorchScript object", context=f"value={self.value}, method={name}", diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index e63edf8e2b0..75928842cf2 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,7 +1,9 @@ -# mypy: ignore-errors - from inspect import getattr_static -from typing import TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING, TypeGuard + +from torch._guards import Source +from torch.backends.cuda import SDPAParams +from torch.fx.proxy import Proxy from ..bytecode_transformation import create_call_function from ..exc import Unsupported @@ -29,9 +31,9 @@ class SDPAParamsVariable(VariableTracker): This is a read-only container.""" @staticmethod - def create(tx: "InstructionTranslator", value, source): - from torch.backends.cuda import SDPAParams - + def create( + tx: "InstructionTranslator", value: Any, source: Source + ) -> VariableTracker: from .torch import TorchInGraphFunctionVariable params = [ @@ -40,12 +42,14 @@ class SDPAParamsVariable(VariableTracker): ] return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) - def __init__(self, proxy, param_vars, **kwargs) -> None: + def __init__( + self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any + ) -> None: self.proxy = proxy self.param_vars = param_vars super().__init__(**kwargs) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: assert self.source is None assert self.param_vars is not None codegen.add_push_null( @@ -54,7 +58,7 @@ class SDPAParamsVariable(VariableTracker): codegen.foreach(self.param_vars) codegen.extend_output(create_call_function(len(self.param_vars), False)) - def as_proxy(self): + def as_proxy(self) -> Proxy: return self.proxy def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -80,7 +84,5 @@ class SDPAParamsVariable(VariableTracker): return wrap_fx_proxy(tx=tx, proxy=proxy) @staticmethod - def is_sdpa_params(value): - from torch.backends.cuda import SDPAParams - + def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]: return value is SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index dce4afe929c..a0da59d740c 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -232,6 +232,7 @@ class StreamVariable(StreamContextVariable): return ConstantVariable.create(NotImplemented) if other.source: + assert self.source is not None install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) return ConstantVariable.create( cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type] diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e48a4881015..18193ddb3f0 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1464,6 +1464,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): ): # constant fold functions need to be guarded. 
if self.value in constant_fold_functions_need_guards: + assert self.source is not None source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 817385ff149..71993a62434 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """TorchDynamo support for __torch_function__ tensor subclasses. This module implements support for tensor subclasses with __torch_function__ overrides. @@ -31,7 +29,8 @@ import contextlib import functools import inspect import operator -from typing import TYPE_CHECKING +from types import TracebackType +from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree @@ -125,34 +124,134 @@ un_ops = [ banned_attrs = [ - fn.__self__.__name__ + fn.__self__.__name__ # type: ignore[attr-defined] for fn in get_default_nowrap_functions() if is_tensor_base_attr_getter(fn) ] @functools.cache -def get_prev_stack_var_name(): +def get_prev_stack_var_name() -> str: from ..bytecode_transformation import unique_id return unique_id("___prev_torch_function_mode_stack") +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool: + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the function across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__ + ) + + def __init__( + self, + value: Optional[TorchFunctionMode], + source: Optional[Source] = None, + **kwargs: Any, + ): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source # type: ignore[assignment] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # This shouldn't be called unless we have a source + assert self.source + self.source.reconstruct(codegen) + + def module_name(self) -> str: + return self.value.__module__ + + def fn_name(self) -> str: + return type(self.value).__name__ + + def python_type(self) -> type: + return type(self.value) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + return call_torch_function( + tx, + get_torch_function_fn(tx, self), # type: ignore[arg-type] + fn, + types, + args, + kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + 
torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return False + + # Used to clear/restore the python torch function mode stack and temporarily restore it as needed class TorchFunctionModeStackStateManager: - def __init__(self): - self.stack = [] + def __init__(self) -> None: + self.stack: list[Any] = [] - def __enter__(self): + def __enter__(self) -> None: self.stack = torch.overrides._get_current_function_mode_stack() clear_torch_function_mode_stack() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: set_torch_function_mode_stack(self.stack) self.stack = [] @contextlib.contextmanager - def temp_restore_stack(self): + def temp_restore_stack(self) -> Generator[None, None, None]: prev = torch.overrides._get_current_function_mode_stack() set_torch_function_mode_stack(self.stack) try: @@ -165,7 +264,7 @@ torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() class SymbolicTorchFunctionState: - def __init__(self, py_stack): + def __init__(self, py_stack: Iterable[Any]) -> None: # This is annoyingly complicated because of how the torch function subclass + mode C API was designed # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass # These are their definitions: @@ -199,32 +298,41 @@ class SymbolicTorchFunctionState: for i, val in enumerate(py_stack): self.mode_stack.append( - LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type] ) - def in_torch_function_mode(self): + def in_torch_function_mode(self) -> bool: return len(self.mode_stack) > 0 - def pop_torch_function_mode(self): + def pop_torch_function_mode(self) -> TorchFunctionModeVariable: return self.mode_stack.pop() - def push_torch_function_mode(self, mode_var): + def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None: self.mode_stack.append(mode_var) - def call_torch_function_mode(self, tx, fn, types, args, kwargs): + def call_torch_function_mode( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: with self._pop_mode_for_inlining() as cur_mode: return cur_mode.call_torch_function(tx, fn, types, args, kwargs) @contextlib.contextmanager - def _pop_mode_for_inlining(self): + def _pop_mode_for_inlining( + self, + ) -> Generator[TorchFunctionModeVariable, None, None]: old_mode = self.cur_mode - self.cur_mode = self.pop_torch_function_mode() + self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment] try: - yield self.cur_mode + yield self.cur_mode # type: ignore[misc] finally: mode = self.cur_mode self.cur_mode = old_mode - 
self.push_torch_function_mode(mode) + self.push_torch_function_mode(mode) # type: ignore[arg-type] class TorchFunctionModeStackVariable(VariableTracker): @@ -244,16 +352,20 @@ class TorchFunctionModeStackVariable(VariableTracker): # each of the indices of other modes should be shifted left by 1 (-1) offset = 0 - def __init__(self, source, symbolic_stack): + def __init__( + self, + source: Source, + symbolic_stack: collections.deque[TorchFunctionModeVariable], + ) -> None: self.source = source self.symbolic_stack = symbolic_stack @classmethod - def reset(cls): + def reset(cls) -> None: cls.offset = 0 @classmethod - def register_mutation(cls, tx: "InstructionTranslator"): + def register_mutation(cls, tx: "InstructionTranslator") -> None: if cls.stack_value_singleton not in tx.output.side_effects: var = cls( source=Source(), @@ -263,7 +375,7 @@ class TorchFunctionModeStackVariable(VariableTracker): tx.output.side_effects.mutation(var) @classmethod - def register_device_context_insertion(cls, tx: "InstructionTranslator"): + def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None: stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return @@ -277,109 +389,28 @@ class TorchFunctionModeStackVariable(VariableTracker): ) @classmethod - def clear_default_device(cls, tx: "InstructionTranslator"): + def clear_default_device(cls, tx: "InstructionTranslator") -> None: stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @staticmethod - def is_device_context(var): + def is_device_context(var: TorchFunctionModeVariable) -> bool: return isinstance(var.value, DeviceContext) or var.value is None @classmethod - def get_mode_index(cls, ind): + def get_mode_index(cls, ind: int) -> int: return ind + cls.offset -class TorchFunctionModeVariable(GenericContextWrappingVariable): - @staticmethod - def is_supported_torch_function_mode(ty): - # Supported in this sense means we can support graph breaks under the - # context. 
- # We are able to trace custom modes but if there are graph breaks under them - # and they have a custom __enter__/__exit__ we don't handle this for the - # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the function across two frames instead of one - # Today we support the enter/exit of the default TorchFunctionMode as well as - # DeviceContext (which is used for set_default_device) - return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( - not class_has_getattribute(ty) - and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ - and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ - ) - - def __init__(self, value, source=None, **kwargs): - if value is not None: - super().__init__(value, **kwargs) - self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code - self.source = source - - def reconstruct(self, codegen: "PyCodegen"): - # This shouldn't be called unless we have a source - assert self.source - self.source.reconstruct(codegen) - - def module_name(self): - return self.value.__module__ - - def fn_name(self): - return type(self.value).__name__ - - def python_type(self): - return type(self.value) - - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): - return call_torch_function( - tx, - get_torch_function_fn(tx, self), - fn, - types, - args, - kwargs, - ) - - def enter(self, tx): - from .torch import TorchInGraphFunctionVariable - - if isinstance(self.value, NoEnterTorchFunctionMode): - return ConstantVariable.create(None) - - TorchInGraphFunctionVariable( - torch._C._push_on_torch_function_stack - ).call_function(tx, [self], {}) - return ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - from .torch import TorchInGraphFunctionVariable - - TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( - tx, [], {} - ) - return ConstantVariable.create(None) - - def reconstruct_type(self, codegen: "PyCodegen"): - ty = NoEnterTorchFunctionMode - codegen( - AttrSource( - codegen.tx.import_source(ty.__module__), - ty.__name__, - ) - ) - - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return False - - -def _get_all_args(args, kwargs): +def _get_all_args( + args: Iterable[Any], kwargs: dict[str, Any] +) -> Iterable[VariableTracker]: return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) -def _flatten_vts(vts): +def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]: from collections import deque from .dicts import ConstDictVariable @@ -391,7 +422,7 @@ def _flatten_vts(vts): while vts: vt = vts.popleft() - if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined] vt.realize() if vt.is_realized(): @@ -407,21 +438,28 @@ def _flatten_vts(vts): return output -def _get_subclass_type(var): +def _get_subclass_type(var: VariableTracker) -> type: assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) return var.python_type() -def _get_subclass_type_var(tx: "InstructionTranslator", var): - assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) +def _get_subclass_type_var( + tx: "InstructionTranslator", var: VariableTracker +) -> VariableTracker: if isinstance(var, TensorWithTFOverrideVariable): return var.class_type_var(tx) elif 
isinstance(var, UserDefinedObjectVariable): source = var.source and TypeSource(var.source) return VariableTracker.build(tx, var.python_type(), source) + else: + raise AssertionError(f"Unexpected type {type(var)}") -def _is_attr_overridden(tx: "InstructionTranslator", var, name): +def _is_attr_overridden( + tx: "InstructionTranslator", var: VariableTracker, name: str +) -> bool: + if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)): + return False import torch overridden = False @@ -434,7 +472,14 @@ def _is_attr_overridden(tx: "InstructionTranslator", var, name): return overridden -def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): +def call_torch_function( + tx: "InstructionTranslator", + torch_function_var: VariableTracker, + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: # This emulates calling __torch_function__, which has a signature # def __torch_function__(cls, func, types, args=(), kwargs=None): # @@ -451,7 +496,9 @@ def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): return torch_function_var.call_function(tx, tf_args, {}) -def get_torch_function_fn(tx: "InstructionTranslator", vt): +def get_torch_function_fn( + tx: "InstructionTranslator", vt: VariableTracker +) -> VariableTracker: # The underlying function could be a classmethod, staticmethod, regular # function or a function with C-implementation. It doesn't matter as long as # they satisfy the calling convention in `call_torch_function`. @@ -462,7 +509,9 @@ def get_torch_function_fn(tx: "InstructionTranslator", vt): return func_vt -def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): +def can_dispatch_torch_function( + tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any] +) -> bool: has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) @@ -472,7 +521,12 @@ def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): ) -def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): +def dispatch_torch_function( + tx: "InstructionTranslator", + fn: VariableTracker, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" all_args = _get_all_args(args, kwargs) @@ -518,7 +572,13 @@ class TensorWithTFOverrideVariable(TensorVariable): """ @classmethod - def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): + def from_tensor_var( + cls, + tx: "InstructionTranslator", + tensor_var: VariableTracker, + class_type: type, + cls_source: Source, + ) -> "TensorWithTFOverrideVariable": # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. import torch @@ -533,7 +593,7 @@ class TensorWithTFOverrideVariable(TensorVariable): var.install_global(tx) return var - def install_global(self, tx): + def install_global(self, tx: "InstructionTranslator") -> None: # stash the subclass type to rewrap an output tensor if needed # this is needed because the actual type needs to be available # each time the compiled artifact is run and outputs a wrapped tensor. 
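(Editor's note on the typing pattern above, not part of the patch: the `_get_subclass_type_var` hunk shows a convention used throughout this PR — when a function declares a non-Optional return type, every `isinstance` branch must return, and a trailing `raise AssertionError` tells mypy the fall-through is unreachable. A minimal, self-contained sketch of the same idea; the `Shape`/`Circle`/`Square` names are illustrative only and do not exist in the codebase.)

```python
class Shape:
    pass


class Circle(Shape):
    radius = 1.0


class Square(Shape):
    side = 2.0


def area(shape: Shape) -> float:
    if isinstance(shape, Circle):
        return 3.14159 * shape.radius**2
    elif isinstance(shape, Square):
        return shape.side**2
    else:
        # Without this branch mypy reports "Missing return statement": the
        # fall-through for other Shape subclasses would implicitly return
        # None, which does not match the declared `-> float`.
        raise AssertionError(f"Unexpected type {type(shape)}")
```
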
@@ -543,20 +603,20 @@ class TensorWithTFOverrideVariable(TensorVariable): self.global_mangled_class_name(tx), self.class_type ) - def python_type(self): + def python_type(self) -> type: return self.class_type - def class_type_var(self, tx): + def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker: return TensorSubclassVariable( self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) ) - def global_mangled_class_name(self, tx): + def global_mangled_class_name(self, tx: "InstructionTranslator") -> str: return get_safe_global_name( tx, f"__subclass_{self.class_type.__name__}", self.class_type ) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: # [Note: __torch_function__] We currently only support attributes that are defined on # base tensors, custom attribute accesses will graph break. import torch @@ -581,7 +641,8 @@ class TensorWithTFOverrideVariable(TensorVariable): and not attr_is_overridden and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) ): - args, kwargs = [self], {} + args = [self] + kwargs: dict[Any, Any] = {} if can_dispatch_torch_function(tx, args, kwargs): get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) @@ -636,7 +697,14 @@ class TensorWithTFOverrideVariable(TensorVariable): return super().var_getattr(tx, name) - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: # NOTE this assumes `__torch_function__` isn't modified during tracing. if not hasattr(self, "torch_function_fn"): self.torch_function_fn = get_torch_function_fn(tx, self) @@ -652,8 +720,8 @@ class TensorWithTFOverrideVariable(TensorVariable): def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker":
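
(Editor's note, not part of the patch: several hunks above, e.g. `_raise_hard_error_if_graph_break` in script_object.py, type decorator factories with `ParamSpec`/`TypeVar` so the wrapped function keeps its exact parameter and return types under mypy. A minimal sketch of that pattern under the assumption of Python 3.10+ for `typing.ParamSpec` (older interpreters would import it from `typing_extensions`); the `log_calls`/`add` names are illustrative only.)

```python
import functools
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def log_calls(label: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # The factory returns a decorator that preserves the wrapped function's
    # parameter types (_P) and return type (_T), so call sites stay checked.
    def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        @functools.wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            print(f"{label}: calling {fn.__name__}")
            return fn(*args, **kwargs)

        return wrapper

    return deco


@log_calls("demo")
def add(x: int, y: int) -> int:
    return x + y


# mypy still sees add as (int, int) -> int, so add("a", "b") is rejected.
print(add(1, 2))
```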