[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)

Provides type coverage for ~3000 LOC and 200 methods in `torch/_dynamo/variables/`.

This is the first part of the final step toward 100% strict type coverage in Dynamo; see the previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)

### Coverage report:
```
mypy torch/_dynamo/variables --linecount-report /tmp/coverage_log
```
Comparing the report before and after this change, coverage goes from 3826 to 7221 lines.
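
For reference, a minimal sketch of one way to pull the before/after totals out of two report directories. It assumes hypothetical directory names and the `linecount.txt` file that `--linecount-report` writes, whose first row holds the overall totals as whitespace-separated counts; the column order can differ between mypy versions, so the index below is an assumption:

```python
# Hypothetical comparison of two mypy --linecount-report outputs.
# Assumes each report dir contains linecount.txt whose first row is the
# overall total and whose second column is the annotated-line count;
# adjust the column index if your mypy version lays the report out differently.
from pathlib import Path


def annotated_lines(report_dir: str, column: int = 1) -> int:
    first_row = Path(report_dir, "linecount.txt").read_text().splitlines()[0]
    return int(first_row.split()[column])


before = annotated_lines("/tmp/coverage_log_before")
after = annotated_lines("/tmp/coverage_log_after")
print(f"annotated lines: {before} -> {after} (+{after - before})")
```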

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42, https://github.com/Skylion007
Author: Lucas Kabela
Date: 2025-10-31 20:42:23 +00:00
Committed by: PyTorch MergeBot
Commit: 4a7bc1d522
Parent: 8209a0506b
15 changed files with 515 additions and 375 deletions

View File

@@ -153,7 +153,7 @@ class PyCodegen:
         self.clear_tos()

     def __call__(
-        self, value: Union[VariableTracker, Source], allow_cache: bool = True
+        self, value: Union[VariableTracker, Source, None], allow_cache: bool = True
     ) -> None:
         """
         Generate code such that top-of-stack (TOS) is set to value.
@@ -188,7 +188,7 @@ class PyCodegen:
            value to handle aliasing (check side_effects.py and search for
            allow_cache=False).
-        b) If value.source is None, this is not allowed. TODO - assert this.
+        b) If value.source is None, this is not allowed

         Notable effects:
         1. `self.top_of_stack` will be set to `value`, if we don't codegen
@@ -197,6 +197,7 @@ class PyCodegen:
            `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
            types like `NNModuleVariable`, etc.
         """
+        assert value is not None
         if isinstance(value, Source):
             # If the source needs to be overridden, use the new one.
             source = self.overridden_sources.get(value, value)
@@ -289,7 +290,8 @@ class PyCodegen:
                 self.load_graph_output(graph_outputs[graph_outputs_key].index)
                 output.append(
                     self.create_load_global(
-                        value.global_mangled_class_name(self.tx), add=True
+                        value.global_mangled_class_name(self.tx),  # type: ignore[arg-type]
+                        add=True,
                     )
                 )
                 output.extend(create_call_function(2, False))

View File

@@ -1303,6 +1303,7 @@ class OutputGraph(OutputGraphCommon):
                 # A small codegen optimization because we might have different
                 # VariableTrackers that share the same source.
+                assert x.source is not None
                 list_idx = x.source.index  # type: ignore[attr-defined]
                 if list_idx not in visited:
                     alias_name = self.new_var(
@@ -1321,6 +1322,7 @@ class OutputGraph(OutputGraphCommon):
                     )

                 # operate on alias, handled by suffix codegen
+                assert x.source is not None
                 old_source = x.source
                 overridden_sources[old_source] = LocalSource(visited[list_idx])
@@ -1864,7 +1866,6 @@ class OutputGraph(OutputGraphCommon):
                 and isinstance(var.value, _ExportModuleSpecTrackerDict)
             ):
                 potential_side_effects.append(var)
-
         side_effect_refs = [
             _get_source_debug_name(var.source) for var in potential_side_effects
         ]

View File

@@ -258,6 +258,7 @@ class SideEffects:
                     "Dynamo needs to fully exhaust the generator, which may cause "
                     "unintended variable modifications."
                 )
+            assert item.mutation_type is not None
             if not is_side_effect_safe(item.mutation_type):
                 # TODO plumb HOP information here
                 unimplemented_v2(
@@ -373,7 +374,7 @@ class SideEffects:
         if self.is_attribute_mutation(item):
             return item in self.store_attr_mutations
-
+        assert item.mutation_type is not None
         return item.mutation_type.is_modified  # type: ignore[attr-defined]

     def _track_obj(

View File

@@ -111,11 +111,14 @@ def is_constant_source(source: Source) -> bool:
     return False


-def _get_source_debug_name(source: Source) -> str:
-    try:
-        return source.name()
-    except NotImplementedError:
+def _get_source_debug_name(source: Optional[Source]) -> str:
+    if source is None:
         return "<unknown source>"
+    else:
+        try:
+            return source.name()
+        except NotImplementedError:
+            return "<unknown source>"


 @dataclasses.dataclass(frozen=True)

View File

@@ -5201,7 +5201,7 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
         ):
             if isinstance(val, ConstantVariable) and val.value is None:
                 try:
-                    val = tos.next_variable(self)
+                    val = tos.next_variable(self)  # type: ignore[arg-type]
                 except (StopIteration, exc.ObservedUserStopIteration) as ex:
                     # To implement SEND, we have to look at the implementation
                     # when the iterator returns StopIteration. This translates to this code

View File

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Core variable tracking functionality for Dynamo. This module defines the fundamental
 classes and systems used to track and manage variables during Dynamo's operation.
@@ -18,7 +16,10 @@ computations.
 import collections
 from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView
 from enum import Enum
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, NoReturn, Optional, TYPE_CHECKING
+
+from torch._guards import Guard
+from torch.fx.proxy import Node

 from .. import graph_break_hints, variables
 from ..current_scope_id import current_scope_id
@@ -30,7 +31,7 @@ from ..utils import cmp_name_to_op_mapping, istype

 if TYPE_CHECKING:
     from ..codegen import PyCodegen
-    from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase
+    from ..symbolic_convert import InstructionTranslator


 class SourceType(Enum):
@@ -115,10 +116,10 @@ class ValueMutationNew(MutationType):
     def __init__(self) -> None:
         super().__init__(SourceType.New)

-    def __hash__(self):
+    def __hash__(self) -> int:
         return id(self)

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return self is other
@@ -139,7 +140,7 @@ class ValueMutationExisting(MutationType):
     # filter out which pre-existing values it needs to generate mutation for.
     is_modified: bool

-    def __init__(self, is_modified: bool = False):
+    def __init__(self, is_modified: bool = False) -> None:
         super().__init__(SourceType.Existing)
         self.is_modified = is_modified
@@ -150,7 +151,7 @@ class AttributeMutation(MutationType):
     allows mutation on the value's attributes.
     """

-    def __init__(self, typ: SourceType):
+    def __init__(self, typ: SourceType) -> None:
         super().__init__(typ)
@@ -166,7 +167,7 @@ class AttributeMutationExisting(AttributeMutation):
     be used afterwards in Python.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(SourceType.Existing)
@@ -182,16 +183,16 @@ class AttributeMutationNew(AttributeMutation):
     the Python world.
     """

-    def __init__(self, cls_source: Optional[Source] = None):
+    def __init__(self, cls_source: Optional[Source] = None) -> None:
         super().__init__(SourceType.New)
         self.cls_source = cls_source


-def _is_top_level_scope(scope_id):
+def _is_top_level_scope(scope_id: int) -> bool:
     return scope_id == 1


-def is_side_effect_safe(m: MutationType):
+def is_side_effect_safe(m: MutationType) -> bool:
     scope_id = current_scope_id()

     # In the top-level scope (if no HigherOrderOperators are involved),
@@ -209,15 +210,15 @@ def is_side_effect_safe(m: MutationType):
 class AsPythonConstantNotImplementedError(NotImplementedError):
     vt: "VariableTracker"

-    def __init__(self, vt: "VariableTracker"):
+    def __init__(self, vt: "VariableTracker") -> None:
         super().__init__(f"{vt} is not a constant")
         self.vt = vt


 class VariableTrackerMeta(type):
-    all_subclasses = []
+    all_subclasses: list[type] = []

-    def __instancecheck__(cls, instance) -> bool:
+    def __instancecheck__(cls: type, instance: object) -> bool:
         """Make isinstance work with LazyVariableTracker"""
         # This is super expensive - just having it costs over 4% of tracing
         # time!
@@ -227,8 +228,10 @@ class VariableTrackerMeta(type):
             instance = instance.realize()
         return type.__instancecheck__(cls, instance)

-    def __init__(cls, name, bases, attrs) -> None:
-        super().__init__(name, bases, attrs)
+    def __init__(
+        cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any]
+    ) -> None:
+        super().__init__(name, bases, attrs)  # type: ignore[misc]
         VariableTrackerMeta.all_subclasses.append(cls)
@@ -252,7 +255,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         "user_code_variable_name",
     }

-    def clone(self, **kwargs):
+    def clone(self, **kwargs: Any) -> "VariableTracker":
         """Shallow copy with some (optional) changes"""
         args = dict(self.__dict__)
         args.update(kwargs)
@@ -295,14 +298,14 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"

-    def debug_repr(self):
+    def debug_repr(self) -> str:
         # Intended to be overridden to provide more info
         try:
             return repr(self.as_python_constant())
         except NotImplementedError:
             return repr(self)

-    def python_type(self):
+    def python_type(self) -> type:
         """
         Abstract method to be implemented by subclasses of VariableTracker.
@@ -331,17 +334,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         except NotImplementedError:
             raise NotImplementedError(f"{self} has no type") from None

-    def python_type_name(self):
+    def python_type_name(self) -> str:
         try:
             return self.python_type().__name__
         except NotImplementedError:
             return "<unknown type>"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         """For constants"""
         raise AsPythonConstantNotImplementedError(self)

-    def guard_as_python_constant(self):
+    def guard_as_python_constant(self) -> Any:
         """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
         try:
             return self.as_python_constant()
@@ -353,18 +356,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             hints=[],
         )

-    def is_python_constant(self):
+    def is_python_constant(self) -> bool:
         try:
             self.as_python_constant()
             return True
         except NotImplementedError:
             return False

-    def make_guard(self, fn):
+    def make_guard(self, fn: Callable[..., Any]) -> Guard:
         if self.source:
             return self.source.make_guard(fn)
         raise NotImplementedError

+    # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase`
+    # and cascade that (large blast radius)
     def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
         """getattr(self, name) returning a python constant"""
         raise NotImplementedError
@@ -381,17 +386,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
         return variables.ConstantVariable.create(value, source=source)

-    def is_proxy(self):
+    def is_proxy(self) -> bool:
         try:
             self.as_proxy()
             return True
         except NotImplementedError:
             return False

-    def as_proxy(self):
+    def as_proxy(self) -> Any:
         raise NotImplementedError(str(self))

-    def maybe_fx_node(self):
+    def maybe_fx_node(self) -> Optional[Node]:
         try:
             proxy = self.as_proxy()
             import torch.fx
@@ -402,13 +407,13 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         except NotImplementedError:
             return None

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         raise NotImplementedError

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
         raise NotImplementedError

-    def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]:
         # like unpack_var_sequence, but should only be used when it is
         # safe to eagerly (vs. lazily) unpack this variable.
         # e.g. map(f, x) is normally evaluated lazily but sometimes
@@ -417,7 +422,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         # it should only be called once.
         return self.unpack_var_sequence(tx)

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: Any) -> bool:
         try:
             self.unpack_var_sequence(tx)
             return True
@@ -425,13 +430,15 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             return False

     # NB: don't call force_unpack_var_sequence, especially if it mutates!
-    def has_force_unpack_var_sequence(self, tx) -> bool:
+    def has_force_unpack_var_sequence(self, tx: Any) -> bool:
         return self.has_unpack_var_sequence(tx)

     # Forces unpacking the var sequence while also applying a function to each element.
     # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence).
     # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True!
-    def force_apply_to_var_sequence(self, tx, fn) -> None:
+    def force_apply_to_var_sequence(
+        self, tx: Any, fn: Callable[["VariableTracker"], Any]
+    ) -> None:
         assert self.has_force_unpack_var_sequence(tx)
         for v in self.unpack_var_sequence(tx):
             fn(v)
@@ -444,9 +451,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             hints=[],
         )

-    def call_obj_hasattr(
-        self, tx: "InstructionTranslator", name: str
-    ) -> "VariableTracker":
+    def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker":
         unimplemented_v2(
             gb_type="Unsupported hasattr call",
             context=f"call_obj_hasattr {self} {name}",
@@ -459,9 +464,9 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def call_function(
         self,
-        tx: "InstructionTranslator",
+        tx: Any,
         args: Sequence["VariableTracker"],
-        kwargs: "dict[str, VariableTracker]",
+        kwargs: dict[str, "VariableTracker"],
     ) -> "VariableTracker":
         unimplemented_v2(
             gb_type="Unsupported function call",
@@ -475,10 +480,10 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def call_method(
         self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
+        tx: Any,
+        name: str,
+        args: list["VariableTracker"],
+        kwargs: dict[str, "VariableTracker"],
     ) -> "VariableTracker":
         if name == "__len__" and self.has_unpack_var_sequence(tx):
             assert not (args or kwargs)
@@ -562,7 +567,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             hints=hints,
         )

-    def set_name_hint(self, name):
+    def set_name_hint(self, name: str) -> None:
         pass

     def realize(self) -> "VariableTracker":
@@ -573,11 +578,11 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
         return self

-    def is_realized(self):
+    def is_realized(self) -> bool:
         """Used by LazyVariableTracker to indicate an unrealized node"""
         return True

-    def next_variable(self, tx):
+    def next_variable(self, tx: Any) -> "VariableTracker":
         unimplemented_v2(
             gb_type="Unsupported next() call",
             context=f"next({self})",
@@ -585,20 +590,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
             hints=[*graph_break_hints.USER_ERROR],
         )

-    def is_strict_mode(self, tx):
-        return tx.strict_checks_fn and tx.strict_checks_fn(self)
+    def is_strict_mode(self, tx: Any) -> bool:
+        return bool(tx.strict_checks_fn and tx.strict_checks_fn(self))

-    def is_mutable(self):
+    def is_mutable(self) -> bool:
         """Whether Dynamo allows mutation on this variable."""
         return not self.is_immutable()

-    def is_immutable(self):
+    def is_immutable(self) -> bool:
         """Whether Dynamo bans mutation on this variable."""
         return self.mutation_type is None

     @staticmethod
     def build(
-        tx: "InstructionTranslatorBase",
+        tx: Any,
         value: Any,
         source: Optional[Source] = None,
     ) -> Any:
@@ -611,8 +616,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def __init__(
         self,
         *,
-        source: Source = None,
-        mutation_type: MutationType = None,
+        source: Optional[Source] = None,
+        mutation_type: Optional[MutationType] = None,
     ) -> None:
         super().__init__()
         self.source = source
@@ -636,12 +641,12 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         assert source is not None


-def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None:
+def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn:
     msg = variables.ConstantVariable.create(msg_str)
     raise_observed_exception(TypeError, tx, args=[msg])


-def typestr(*objs):
+def typestr(*objs: object) -> str:
     if len(objs) == 1:
         (obj,) = objs
         if isinstance(obj, VariableTracker):

View File

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Constant and enum variable tracking in Dynamo.
@@ -8,8 +6,9 @@ values during compilation, ensuring proper handling of Python literals and
 maintaining type safety through the compilation process.
 """

+import enum
 import operator
-from typing import TYPE_CHECKING
+from typing import Any, Literal, Optional, TYPE_CHECKING, Union

 import torch
 from torch._dynamo.source import AttrSource, GetItemSource
@@ -40,7 +39,7 @@ class ConstantVariable(VariableTracker):
     """

     @staticmethod
-    def create(value, **kwargs) -> VariableTracker:
+    def create(value: Any, **kwargs: Any) -> VariableTracker:
         """
         Create a `ConstantVariable` based on the given value, and supports
         automatic routing for collection types like `tuple` (in which case we'd
@@ -76,7 +75,7 @@ its type to `common_constant_types`.
         return ConstantVariable(value, **kwargs)

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         assert ConstantVariable.is_base_literal(value), f"""
 Cannot construct `ConstantVariable` for value of type {type(value)}.
@@ -92,48 +91,52 @@ its type to `common_constant_types`.
         else:
             self.value = value

-    def as_proxy(self):
+    def as_proxy(self) -> Any:
         return self.value

     def __repr__(self) -> str:
         return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

-    def is_python_constant(self):
+    def is_python_constant(self) -> Literal[True]:
         return True

     @property
-    def items(self):
+    def items(self) -> list[VariableTracker]:
         """
         Need this when adding a BaseListVariable and a ConstantVariable together.
         Happens in detectron2.
         """
         return self.unpack_var_sequence(tx=None)

-    def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
+    def getitem_const(
+        self, tx: "InstructionTranslator", arg: VariableTracker
+    ) -> VariableTracker:
         return ConstantVariable.create(
             self.value[arg.as_python_constant()],
         )

     @staticmethod
-    def is_base_literal(obj):
+    def is_base_literal(obj: object) -> bool:
         return type(obj) in common_constant_types

     @staticmethod
-    def is_literal(obj):
+    def is_literal(obj: object) -> bool:
         if type(obj) in (list, tuple, set, frozenset, torch.Size):
-            return all(ConstantVariable.is_literal(x) for x in obj)
+            return all(ConstantVariable.is_literal(x) for x in obj)  # type: ignore[attr-defined]
         return ConstantVariable.is_base_literal(obj)

-    def unpack_var_sequence(self, tx):
+    def unpack_var_sequence(
+        self, tx: Optional["InstructionTranslator"]
+    ) -> list[VariableTracker]:
         try:
             return [ConstantVariable.create(x) for x in self.as_python_constant()]
         except TypeError as e:
             raise NotImplementedError from e

-    def const_getattr(self, tx: "InstructionTranslator", name):
+    def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         if not hasattr(self.value, name):
             raise_observed_exception(AttributeError, tx, args=[name])
         member = getattr(self.value, name)
@@ -144,10 +147,10 @@ its type to `common_constant_types`.
     def call_method(
         self,
         tx: "InstructionTranslator",
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
         from .tensor import SymNodeVariable

         if name == "format" and istype(self.value, str):
@@ -262,7 +265,7 @@ its type to `common_constant_types`.
     def call_obj_hasattr(
         self, tx: "InstructionTranslator", name: str
-    ) -> "VariableTracker":
+    ) -> VariableTracker:
         result = hasattr(self.value, name)
         return variables.ConstantVariable.create(result)
@@ -274,12 +277,14 @@ class EnumVariable(VariableTracker):
     both standard Enum and IntEnum with proper value tracking and comparison.
     """

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.value = value

     @classmethod
-    def create(cls, cls_type, value_vt, options):
+    def create(
+        cls, cls_type: Any, value_vt: VariableTracker, options: Any
+    ) -> "EnumVariable":
         if isinstance(value_vt, variables.ConstantVariable):
             for member in list(cls_type):
                 if member.value == value_vt.as_python_constant():
@@ -293,7 +298,7 @@ class EnumVariable(VariableTracker):
             hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE],
         )

-    def as_proxy(self):
+    def as_proxy(self) -> Union[enum.Enum, int]:
         if isinstance(self.value, int):
             return int(self.value)  # convert IntEnum to a normal int
         return self.value
@@ -301,10 +306,10 @@ class EnumVariable(VariableTracker):
     def __repr__(self) -> str:
         return f"EnumVariable({type(self.value)})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]:
         return self.value

-    def var_getattr(self, tx: "InstructionTranslator", name):
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         if not hasattr(self.value, name):
             raise NotImplementedError
         if name in cmp_name_to_op_mapping:

View File

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 Distributed computing variable tracking classes for PyTorch Dynamo.
@@ -22,7 +20,7 @@ checks and proper tracking of distributed state and operations across processes.
 import functools
 import inspect
-from typing import TYPE_CHECKING
+from typing import Any, Sequence, TYPE_CHECKING

 import torch
 from torch.fx.experimental._backward_state import BackwardState
@@ -40,6 +38,7 @@ from .constant import ConstantVariable, EnumVariable

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator
@@ -54,7 +53,7 @@ class DistributedVariable(VariableTracker):
     and hold the tracking value for the corresponding distributed object.
     """

-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         if not DistributedVariable.is_available():
             unimplemented_v2(
@@ -67,16 +66,16 @@ class DistributedVariable(VariableTracker):
             )
         self.value = value

-    def python_type(self):
+    def python_type(self) -> type:
         return type(self.value)

     @staticmethod
-    def is_available():
+    def is_available() -> bool:
         # check if the distributed package is available or not
         return torch.distributed.is_available()


-def is_from_local(value):
+def is_from_local(value: object) -> bool:
     if not DistributedVariable.is_available():
         return False
     from torch.distributed.tensor import DTensor
@@ -84,7 +83,7 @@ def is_from_local(value):
     return inspect.isfunction(value) and value is DTensor.from_local


-def is_constant_pg_functions(value):
+def is_constant_pg_functions(value: object) -> bool:
     if not DistributedVariable.is_available():
         return False
@@ -114,7 +113,7 @@ class WorldMetaClassVariable(DistributedVariable):
     """

     @classmethod
-    def is_group_member_type(cls, value):
+    def is_group_member_type(cls, value: object) -> bool:
         if not cls.is_available():
             return False
@@ -124,10 +123,12 @@ class WorldMetaClassVariable(DistributedVariable):
     def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         if name == "WORLD":
+            assert self.source
             source = AttrSource(base=self.source, member="WORLD")
             install_guard(source.make_guard(GuardBuilder.ID_MATCH))
             return ProcessGroupVariable(self.value.WORLD)
         elif name == "NON_GROUP_MEMBER":
+            assert self.source
             source = AttrSource(base=self.source, member="NON_GROUP_MEMBER")
             install_guard(source.make_guard(GuardBuilder.ID_MATCH))
             return EnumVariable(self.value.NON_GROUP_MEMBER)
@@ -136,7 +137,7 @@ class WorldMetaClassVariable(DistributedVariable):

 class PlacementClassVariable(DistributedVariable):
     @staticmethod
-    def is_placement_type(value):
+    def is_placement_type(value: object) -> bool:
         # we can't rely on importing/accessing torch distributed, it is not always built.
         if not DistributedVariable.is_available():
             return False
@@ -145,15 +146,15 @@ class PlacementClassVariable(DistributedVariable):
         return isinstance(value, type) and issubclass(value, Placement)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        args: Sequence[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
         if self.source:
             # NOTE: we don't need to track mutations to the placement class as they
             # are supposed to be immutable.
@@ -168,16 +169,15 @@ class PlacementClassVariable(DistributedVariable):

 class PlacementVariable(DistributedVariable):
     @staticmethod
-    def is_placement(value):
+    def is_placement(value: object) -> bool:
         # we can't rely on importing/accessing torch distributed, it is not always built.
         if not DistributedVariable.is_available():
             return False

         from torch.distributed.tensor.placement_types import Placement
-
         return isinstance(value, Placement)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@@ -187,11 +187,11 @@ class PlacementVariable(DistributedVariable):
     def call_method(
         self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: Sequence[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
         from . import ConstantVariable

         # Placement types dynamo tracking only allows following methods
@@ -228,15 +228,16 @@ class PlacementVariable(DistributedVariable):
             args = [x.as_python_constant() for x in args]
             kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
+            assert method is not None
             if name == "__setattr__":
                 method(self.value, *args, **kwargs)
                 return self
             constant_val = method(self.value, *args, **kwargs)
             return ConstantVariable.create(constant_val)

-        return super().call_method(tx, name, args, kwargs)
+        return super().call_method(tx, name, args, kwargs)  # type: ignore[arg-type]

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         # Reconstruct the Placement object by calling its constructor
         # e.g., Shard(0), Replicate(), Partial()
         from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
@@ -263,7 +264,7 @@ class PlacementVariable(DistributedVariable):

 class DeviceMeshVariable(DistributedVariable):
     @staticmethod
-    def is_device_mesh(value):
+    def is_device_mesh(value: object) -> bool:
         # we can't rely on importing/accessing torch distributed, it is not always built.
         if not DistributedVariable.is_available():
             return False
@@ -272,7 +273,7 @@ class DeviceMeshVariable(DistributedVariable):
         return istype(value, DeviceMesh)

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@@ -289,11 +290,11 @@ class DeviceMeshVariable(DistributedVariable):
     def call_method(
         self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
         if name == "size":
             const_args = [x.as_python_constant() for x in args]
             const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
@@ -338,16 +339,16 @@ class ProcessGroupVariable(DistributedVariable):
     or just graph-break whenever one of our special cases is not hit?
     """

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

     def call_method(
         self,
-        tx,
-        name,
-        args: "list[VariableTracker]",
-        kwargs: "dict[str, VariableTracker]",
-    ) -> "VariableTracker":
+        tx: "InstructionTranslator",
+        name: str,
+        args: list[VariableTracker],
+        kwargs: dict[str, VariableTracker],
+    ) -> VariableTracker:
         if name == "rank":
             return variables.ConstantVariable.create(self.value.rank())
         if name == "size":
@@ -357,7 +358,7 @@ class ProcessGroupVariable(DistributedVariable):

         return super().call_method(tx, name, args, kwargs)

-    def var_getattr(self, tx: "InstructionTranslator", name):
+    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         if name == "group_name":
             return variables.ConstantVariable.create(self.value.group_name)
         if name in ["rank", "size"]:
@@ -368,7 +369,7 @@ class ProcessGroupVariable(DistributedVariable):
         return super().var_getattr(tx, name)

     @staticmethod
-    def is_process_group(value):
+    def is_process_group(value: object) -> bool:
         # we can't rely on importing/accessing torch distributed, it is not always built.
         if not DistributedVariable.is_available():
             return False
@@ -386,11 +387,11 @@ class BackwardHookVariable(VariableTracker):

     @staticmethod
     def create(
-        tx,
+        tx: "InstructionTranslator",
         module: VariableTracker,
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
-    ):
+    ) -> "BackwardHookVariable":
         if not compiled_autograd.compiled_autograd_enabled:
             unimplemented_v2(
                 gb_type="Module-level backwards hooks require compiled autograd.",
@@ -401,7 +402,9 @@ class BackwardHookVariable(VariableTracker):
                 ],
             )

-        def _in_graph_bw_hooks(bw_state: BackwardState):
+        def _in_graph_bw_hooks(
+            bw_state: BackwardState,
+        ) -> torch.utils.hooks.BackwardHook:
             """
             Rather than installing the user hooks in the graph (which
             don't survive AotAutograd), we install hooks that will call
@@ -448,7 +451,7 @@ class BackwardHookVariable(VariableTracker):
         module: VariableTracker,
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
-        **options,
+        **options: Any,
     ) -> None:
         super().__init__(**options)
         self.proxy = proxy
@@ -456,13 +459,13 @@ class BackwardHookVariable(VariableTracker):
         self.user_hooks = user_hooks
         self.user_pre_hooks = user_pre_hooks

-    def as_proxy(self):
+    def as_proxy(self) -> torch.fx.Proxy:
         return self.proxy

     def call_method(
         self,
-        tx,
-        name,
+        tx: "InstructionTranslator",
+        name: str,
         args: list[VariableTracker],
         kwargs: dict[str, VariableTracker],
     ) -> VariableTracker:
@@ -470,7 +473,9 @@ class BackwardHookVariable(VariableTracker):
             return self._setup_hook(tx, name, *args, **kwargs)
         return super().call_method(tx, name, args, kwargs)

-    def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
+    def _setup_hook(
+        self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker
+    ) -> VariableTracker:
         from .builder import wrap_fx_proxy

         return wrap_fx_proxy(

View File

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module provides iterator-related variable tracking functionality for Dynamo.
 It implements variable classes for handling Python iterators and itertools functions
@@ -16,7 +14,8 @@ handling of iterator operations during code transformation and optimization.
 """

 import itertools
-from typing import TYPE_CHECKING, Union
+from collections.abc import Callable
+from typing import Any, Sequence, TYPE_CHECKING, Union

 from .. import graph_break_hints, polyfills, variables
 from ..bytecode_transformation import (
@@ -45,20 +44,20 @@ MAX_ITERATOR_LIMIT = 100 * 1024  # 100k

 class ItertoolsVariable(VariableTracker):
-    def __init__(self, value, **kwargs) -> None:
+    def __init__(self, value: Any, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.value = value

     def __repr__(self) -> str:
         return f"ItertoolsVariable({self.value})"

-    def as_python_constant(self):
+    def as_python_constant(self) -> Any:
         return self.value

     def call_function(
         self,
         tx: "InstructionTranslator",
-        args: "list[VariableTracker]",
+        args: Sequence["VariableTracker"],
         kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         # See also: module `torch._dynamo.polyfills.itertools`
@@ -111,7 +110,7 @@ class ItertoolsVariable(VariableTracker):
                     hints=[*graph_break_hints.USER_ERROR],
                 )

-            def retrieve_const_key(key):
+            def retrieve_const_key(key: VariableTracker) -> Any:
                 if isinstance(key, variables.SymNodeVariable):
                     return key.evaluate_expr()
                 elif isinstance(key, variables.ConstantVariable):
@@ -144,14 +143,14 @@ class ItertoolsVariable(VariableTracker):
             if "key" in kwargs:

-                def keyfunc(x):
+                def keyfunc(x: VariableTracker) -> Any:
                     return retrieve_const_key(
-                        kwargs.get("key").call_function(tx, [x], {})
+                        kwargs.get("key").call_function(tx, [x], {})  # type: ignore[union-attr]
                     )

             else:

-                def keyfunc(x):
+                def keyfunc(x: VariableTracker) -> Any:
                     return retrieve_const_key(x)

             result = []
@@ -219,10 +218,10 @@ class ItertoolsVariable(VariableTracker):

 class IteratorVariable(VariableTracker):
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         unimplemented_v2(
             gb_type="Unimplemented next() call",
             context=f"next({self})",
@@ -234,12 +233,16 @@ class IteratorVariable(VariableTracker):
     # Normally, iterators are accessed lazily.
     # Example of safe eager unpacking: list(map(f, seq))
     # Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
-    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
-        result = []
+    def force_unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list[VariableTracker]:
+        result: list[VariableTracker] = []
         self.force_apply_to_var_sequence(tx, result.append)
         return result

-    def force_apply_to_var_sequence(self, tx, fn) -> None:
+    def force_apply_to_var_sequence(
+        self, tx: "InstructionTranslator", fn: Callable[[Any], Any]
+    ) -> None:
         while True:
             try:
                 fn(self.next_variable(tx))
@@ -249,7 +252,7 @@ class IteratorVariable(VariableTracker):

     # don't call force_unpack_var_sequence since it can mutate
     # IteratorVariable state!
-    def has_force_unpack_var_sequence(self, tx) -> bool:
+    def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
         return True

     def call_obj_hasattr(
@@ -257,12 +260,12 @@ class IteratorVariable(VariableTracker):
     ) -> "VariableTracker":
         if name == "__iter__" or name == "__next__":
             return variables.ConstantVariable.create(True)
-        super().call_obj_hasattr(tx, name)
+        return super().call_obj_hasattr(tx, name)

     def call_method(
         self,
         tx: "InstructionTranslator",
-        name,
+        name: str,
         args: list[VariableTracker],
         kwargs: dict[str, VariableTracker],
     ) -> VariableTracker:
@@ -287,12 +290,12 @@ class ObjectIteratorVariable(IteratorVariable):
     > list(b)  # empty list
     """

-    def __init__(self, obj: VariableTracker, **kwargs):
+    def __init__(self, obj: VariableTracker, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.obj = obj
         self.generator_exhausted = False

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         if self.generator_exhausted:
             raise_observed_exception(StopIteration, tx)
@@ -306,15 +309,15 @@

 class RepeatIteratorVariable(IteratorVariable):
-    def __init__(self, item: VariableTracker, **kwargs) -> None:
+    def __init__(self, item: VariableTracker, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.item = item

     # Repeat needs no mutation, clone self
-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         return self.item

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -328,7 +331,12 @@ class RepeatIteratorVariable(IteratorVariable):

 class CountIteratorVariable(IteratorVariable):
-    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
+    def __init__(
+        self,
+        item: Union[int, VariableTracker] = 0,
+        step: Union[int, VariableTracker] = 1,
+        **kwargs: Any,
+    ) -> None:
         super().__init__(**kwargs)
         if not isinstance(item, VariableTracker):
             item = ConstantVariable.create(item)
@@ -337,14 +345,14 @@ class CountIteratorVariable(IteratorVariable):
         self.item = item
         self.step = step

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         assert self.is_mutable()
         old_item = self.item
         tx.output.side_effects.mutation(self)
         self.item = self.item.call_method(tx, "__add__", [self.step], {})
         return old_item

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -373,7 +381,7 @@ class ZipVariable(IteratorVariable):
         self,
         iterables: list[VariableTracker],
         strict: bool = False,
-        **kwargs,
+        **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         assert isinstance(iterables, list)
@@ -382,16 +390,18 @@ class ZipVariable(IteratorVariable):
         self.index = 0
         self.strict = strict

-    def python_type(self):
+    def python_type(self) -> type[zip]:  # type: ignore[type-arg]
         return zip

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
         return all(
             isinstance(it, list) or it.has_unpack_var_sequence(tx)
             for it in self.iterables
         )

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list["VariableTracker"]:
         assert self.has_unpack_var_sequence(tx)
         iterables = []
         for it in self.iterables:
@@ -403,7 +413,7 @@ class ZipVariable(IteratorVariable):
         zipped = zip(*iterables, **kwargs)
         return [variables.TupleVariable(list(var)) for var in zipped]

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         assert self.is_mutable()

         if len(self.iterables) == 0:
@@ -412,7 +422,9 @@ class ZipVariable(IteratorVariable):
         old_index = self.index
         args = []

-        def get_item(it):
+        def get_item(
+            it: Union[list[VariableTracker], VariableTracker],
+        ) -> VariableTracker:
             if isinstance(it, list):
                 if old_index >= len(it):
                     raise_observed_exception(StopIteration, tx)
@@ -441,7 +453,7 @@ class ZipVariable(IteratorVariable):
                         raise
                     handle_observed_exception(tx)
                     raise UserError(
-                        ValueError,
+                        ValueError,  # type: ignore[arg-type]
                         "zip() has one argument of len differing from others",
                     ) from None
                 raise
@@ -450,7 +462,7 @@ class ZipVariable(IteratorVariable):
         self.index += 1
         return variables.TupleVariable(args)

-    def reconstruct_items(self, codegen: "PyCodegen"):
+    def reconstruct_items(self, codegen: "PyCodegen") -> None:
         for it in self.iterables:
             if isinstance(it, list):
                 remaining_items = it[self.index :]
@@ -459,7 +471,7 @@ class ZipVariable(IteratorVariable):
             else:
                 codegen(it)

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         codegen.add_push_null(
             lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
         )
@@ -483,23 +495,23 @@ class MapVariable(ZipVariable):
     def __init__(
         self,
         fn: VariableTracker,
-        iterables: list[Union[list[VariableTracker], VariableTracker]],
-        **kwargs,
+        iterables: list[VariableTracker],
+        **kwargs: Any,
     ) -> None:
         super().__init__(iterables, **kwargs)
         self.fn = fn

-    def python_type(self):
+    def python_type(self) -> type:
         return map

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
         return False

-    def next_variable(self, tx):
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
         args = super().next_variable(tx)
-        return self.fn.call_function(tx, args.items, {})
+        return self.fn.call_function(tx, args.items, {})  # type: ignore[attr-defined]

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         codegen.add_push_null(
             lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
         )
@@ -526,23 +538,25 @@ class FilterVariable(IteratorVariable):
     def __init__(
         self,
         fn: VariableTracker,
-        iterable: Union[list[VariableTracker], VariableTracker],
-        **kwargs,
+        iterable: list[VariableTracker],
+        **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.fn = fn
         self.iterable = iterable
         self.index = 0

-    def python_type(self):
+    def python_type(self) -> type:
         return filter

-    def has_unpack_var_sequence(self, tx) -> bool:
+    def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
         return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence(
             tx
         )

-    def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
+    def unpack_var_sequence(
+        self, tx: "InstructionTranslator"
+    ) -> list["VariableTracker"]:
         assert self.has_unpack_var_sequence(tx)
         it = None
         if isinstance(self.iterable, list):
@@ -552,8 +566,8 @@ class FilterVariable(IteratorVariable):
         filtered = self.fn.call_function(tx, it, {})
         return [variables.TupleVariable([filtered])]

-    def next_variable(self, tx):
-        def _next():
+    def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
+        def _next() -> VariableTracker:
             old_index = self.index
             if isinstance(self.iterable, list):
                 if old_index >= len(self.iterable):
@@ -576,7 +590,7 @@ class FilterVariable(IteratorVariable):
             if pred_res.as_python_constant():
                 return item

-    def reconstruct_items(self, codegen: "PyCodegen"):
+    def reconstruct_items(self, codegen: "PyCodegen") -> None:
         if isinstance(self.iterable, list):
             remaining_items = self.iterable[self.index :]
             codegen.foreach(remaining_items)
@@ -584,7 +598,7 @@ class FilterVariable(IteratorVariable):
         else:
             codegen(self.iterable)

-    def reconstruct(self, codegen: "PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
         codegen(self.fn)
         self.reconstruct_items(codegen)

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
""" """
This module implements variable tracking for PyTorch optimizers during Dynamo tracing. This module implements variable tracking for PyTorch optimizers during Dynamo tracing.
@ -24,9 +22,11 @@ optimizer-specific optimizations and safety guarantees.
import logging import logging
import weakref import weakref
from typing import TYPE_CHECKING from typing import Any, Iterable, Optional, TYPE_CHECKING
import torch import torch
from torch._dynamo.variables.tensor import TensorVariable
from torch._guards import Source
from torch._logging import getArtifactLogger from torch._logging import getArtifactLogger
from torch.utils._pytree import tree_map_only from torch.utils._pytree import tree_map_only
@ -63,13 +63,14 @@ class GuardInstallException(Exception):
perf_hint_log = getArtifactLogger(__name__, "perf_hints") perf_hint_log = getArtifactLogger(__name__, "perf_hints")
def _is_static_for_cudagraphs(x): def _is_static_for_cudagraphs(x: torch.Tensor) -> bool:
from torch._inductor.cudagraph_trees import get_manager from torch._inductor.cudagraph_trees import get_manager
if x.is_cuda: if x.is_cuda:
manager = get_manager(x.device.index, False) manager = get_manager(x.device.index, False)
is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
if manager: if manager:
assert manager.current_node is not None
return ( return (
is_static_address is_static_address
or manager.current_node._is_cuda_graph_recorded_tensor(x) or manager.current_node._is_cuda_graph_recorded_tensor(x)
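The added `assert manager.current_node is not None` is the standard way to narrow an `Optional` attribute before a method call; after the assert, mypy treats the attribute as non-None. A minimal reproduction of the idea (the names are illustrative, not the cudagraph manager API):

```python
from typing import Optional

class Node:
    def is_recorded(self) -> bool:
        return True

class Manager:
    def __init__(self) -> None:
        self.current_node: Optional[Node] = Node()

def check(manager: Manager) -> bool:
    # Without this assert, mypy reports:
    #   Item "None" of "Optional[Node]" has no attribute "is_recorded"
    assert manager.current_node is not None
    return manager.current_node.is_recorded()

print(check(Manager()))
```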
@ -91,26 +92,30 @@ class OptimizerVariable(UserDefinedObjectVariable):
def __init__( def __init__(
self, self,
value, value: torch.optim.Optimizer,
grad_to_source=None, grad_to_source: Optional[dict[Any, GradSource]] = None,
static_tensor_names=None, static_tensor_names: Optional[set[str]] = None,
tensor_to_source=None, tensor_to_source: Optional[dict[torch.Tensor, Source]] = None,
**kwargs, **kwargs: Any,
) -> None: ) -> None:
super().__init__(value, **kwargs) super().__init__(value, **kwargs)
self.value: torch.optim.Optimizer = value
self.grad_to_source = grad_to_source or {} self.grad_to_source = grad_to_source or {}
self.tensor_to_source = tensor_to_source or {} self.tensor_to_source = tensor_to_source or {}
self.static_tensor_names = static_tensor_names or set() self.static_tensor_names = static_tensor_names or set()
def call_method( def call_method(
self, self,
tx, tx: "InstructionTranslator",
name, name: str,
args: "list[VariableTracker]", args: list[VariableTracker],
kwargs: "dict[str, VariableTracker]", kwargs: dict[str, VariableTracker],
) -> "VariableTracker": ) -> "VariableTracker":
"""This is an optimization to avoid tracing the very slow initialization of the optimizer""" """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
if name == "_init_group": if name == "_init_group":
if not hasattr(self.value, "_init_group"):
# Fallback: if the optimizer does not have _init_group, trace normally
return super().call_method(tx, name, args, kwargs)
try: try:
self.graph_break_if_pending_mutation(tx) self.graph_break_if_pending_mutation(tx)
self.move_step_if_cpu() self.move_step_if_cpu()
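The constructor above uses `Optional[...] = None` defaults and then normalizes with `x or {}`, the usual way to avoid mutable default arguments while keeping the stored attributes non-Optional. A stripped-down version of the idiom with generic names, not the actual optimizer fields:

```python
from typing import Any, Optional

class Tracker:
    def __init__(
        self,
        grad_to_source: Optional[dict[Any, str]] = None,
        static_tensor_names: Optional[set[str]] = None,
    ) -> None:
        # Normalizing here means later code never has to handle None.
        self.grad_to_source: dict[Any, str] = grad_to_source or {}
        self.static_tensor_names: set[str] = static_tensor_names or set()

t = Tracker()
print(t.grad_to_source, t.static_tensor_names)
```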
@ -135,11 +140,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
return super().call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name): def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
# Note: this allows us to intercept the call in call_method # Note: this allows us to intercept the call in call_method
# in the typical case, we return a UserMethodVariable # in the typical case, we return a UserMethodVariable
# which will directly inline # which will directly inline
if name in ("_init_group", "step"): if name in ("_init_group", "step"):
assert self.source
return GetAttrVariable(self, name, source=AttrSource(self.source, name)) return GetAttrVariable(self, name, source=AttrSource(self.source, name))
if name == "param_groups": if name == "param_groups":
@ -153,7 +159,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
return super().var_getattr(tx, name) return super().var_getattr(tx, name)
def graph_break_if_pending_mutation(self, tx): def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None:
# If there are pending mutations on a parameter (due to using closure) # If there are pending mutations on a parameter (due to using closure)
# then we need to graph break to allow the python version of the parameter # then we need to graph break to allow the python version of the parameter
# to update, so that running _init_group will initialize the states with # to update, so that running _init_group will initialize the states with
@ -167,12 +173,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
raise Unsupported("Pending mutation on parameter") raise Unsupported("Pending mutation on parameter")
def _set_capturable(self, tx): def _set_capturable(self, tx: "InstructionTranslator") -> None:
from . import LazyVariableTracker from . import LazyVariableTracker
# We only set capturable if params are on cuda # We only set capturable if params are on cuda
# and the state is not initialized # and the state is not initialized
def safe_to_set_capturable(group): def safe_to_set_capturable(group: dict[str, Any]) -> bool:
all_uninitialized = True all_uninitialized = True
all_gpu = True all_gpu = True
@ -199,10 +205,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
) )
param_group_vt.items[key] = ConstantVariable.create(True) param_group_vt.items[key] = ConstantVariable.create(True)
def get_python_args(self, *args, **kwargs): def get_python_args(
self, *args: Any, **kwargs: Any
) -> tuple[list[Any], dict[str, Any]]:
"""Get python values equivalent to the variable tracker args""" """Get python values equivalent to the variable tracker args"""
def map_arg(arg): def map_arg(arg: Any) -> Any:
if isinstance(arg, ConstantVariable): if isinstance(arg, ConstantVariable):
return arg.as_python_constant() return arg.as_python_constant()
elif isinstance(arg, ListVariable) and not arg.items: elif isinstance(arg, ListVariable) and not arg.items:
@ -227,19 +235,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
# if this is the case, move it to the GPU # if this is the case, move it to the GPU
# corresponding to the parameter # corresponding to the parameter
# in most cases this is a no-op because the state is empty # in most cases this is a no-op because the state is empty
def move_step_if_cpu(self): def move_step_if_cpu(self) -> None:
for p, state in self.value.state.items(): for p, state in self.value.state.items():
if "step" in state and state["step"].is_cpu: if "step" in state and state["step"].is_cpu:
state["step"] = state["step"].to(p.device) state["step"] = state["step"].to(p.device)
def map_sources_and_install_guards(self, tx): def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None:
from ..decorators import mark_static_address from ..decorators import mark_static_address
from .lazy import LazyVariableTracker from .lazy import LazyVariableTracker
self.grad_to_source = {} self.grad_to_source = {}
self.tensor_to_source = {} self.tensor_to_source = {}
def mark_static(x): def mark_static(x: Any) -> None:
mark_static_address(x, guard=True) mark_static_address(x, guard=True)
tree_map_only(torch.Tensor, mark_static, self.value.state) tree_map_only(torch.Tensor, mark_static, self.value.state)
@ -252,12 +260,12 @@ class OptimizerVariable(UserDefinedObjectVariable):
) )
state_source = self.source and AttrSource(self.source, "state") state_source = self.source and AttrSource(self.source, "state")
state_vt = VariableTracker.build(tx, self.value.state, state_source) state_vt = VariableTracker.build(tx, self.value.state, state_source)
# We need to realize the top level state dict to populate # We need to realize the top level state dict to populate
# the guard locals # the guard locals
state_vt.realize() state_vt.realize()
assert state_source is not None
tx.output.guard_on_key_order.add(state_source) tx.output.guard_on_key_order.add(state_source)
# Populate self.grad_to_source and self.tensor_to_source so that we can # Populate self.grad_to_source and self.tensor_to_source so that we can
@ -308,14 +316,14 @@ class OptimizerVariable(UserDefinedObjectVariable):
# Note: to avoid spam logs only warn if perf hint artifact is enabled # Note: to avoid spam logs only warn if perf hint artifact is enabled
# (NB: artifacts are only enabled at the debug or warning level) # (NB: artifacts are only enabled at the debug or warning level)
if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
non_static_grads = [src.name() for src in non_static_grads] non_static_grad_names = [src.name() for src in non_static_grads]
perf_hint_log.warning( perf_hint_log.warning(
( (
"Grad tensors %s will be copied during cudagraphs execution." "Grad tensors %s will be copied during cudagraphs execution."
"If using cudagraphs and the grad tensor addresses will be the same across runs," "If using cudagraphs and the grad tensor addresses will be the same across runs,"
" use torch._dynamo.decorators.mark_static_address to elide this copy.", " use torch._dynamo.decorators.mark_static_address to elide this copy.",
), ),
non_static_grads, non_static_grad_names,
) )
# We have to again iterate over the state dict to collect the # We have to again iterate over the state dict to collect the
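Renaming `non_static_grads` to `non_static_grad_names` in the hunk above avoids re-binding a single variable first to a list of sources and then to a list of strings, which mypy rejects as an incompatible assignment. A toy reproduction of the error and the fix:

```python
class Source:
    def name(self) -> str:
        return "grad_source"

non_static_grads: list[Source] = [Source()]

# Re-using the same name would fail type checking:
#   error: Incompatible types in assignment (expression has type "list[str]",
#          variable has type "list[Source]")
# non_static_grads = [src.name() for src in non_static_grads]

# A fresh name keeps both variables precisely typed.
non_static_grad_names: list[str] = [src.name() for src in non_static_grads]
print(non_static_grad_names)
```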
@ -335,7 +343,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
p_state_source, ConstDictKeySource(p_state_source, inner_idx) p_state_source, ConstDictKeySource(p_state_source, inner_idx)
) )
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): def wrap_tensor(
self, tx: "InstructionTranslator", tensor_value: torch.Tensor
) -> TensorVariable:
"""Wrap state tensor in a TensorVariable""" """Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address from ..decorators import mark_static_address
@ -362,8 +372,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
return VariableTracker.build(tx, tensor_value, source) return VariableTracker.build(tx, tensor_value, source)
def update_list_args( def update_list_args(
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs self,
): tx: "InstructionTranslator",
args: Iterable[VariableTracker],
kwargs: Any,
py_args: Iterable[Any],
py_kwargs: Any,
) -> None:
"""Update the args and kwargs to the traced optimizer call""" """Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args): for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable): if isinstance(arg, ListVariable):
@ -378,13 +393,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
source = arg.source and GetItemSource(arg.source, i) source = arg.source and GetItemSource(arg.source, i)
arg.items.append(VariableTracker.build(tx, val, source)) arg.items.append(VariableTracker.build(tx, val, source))
def create_finalizer(self, tx): def create_finalizer(self, tx: "InstructionTranslator") -> None:
names_to_delete = self.static_tensor_names names_to_delete = self.static_tensor_names
value = self.value value = self.value
tc = tx.output.tracing_context tc = tx.output.tracing_context
def init_finalizer(gm): def init_finalizer(gm: torch.fx.GraphModule) -> None:
def clear_static_tensor_refs(): def clear_static_tensor_refs() -> None:
for name in names_to_delete: for name in names_to_delete:
gm._buffers.pop(name, None) gm._buffers.pop(name, None)
gm._parameters.pop(name, None) gm._parameters.pop(name, None)

View File

@ -1,6 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
""" """
This module implements variable tracking for TorchScript objects during Dynamo tracing. This module implements variable tracking for TorchScript objects during Dynamo tracing.
@ -22,8 +19,13 @@ by limiting operations to known-safe patterns and failing fast for unsafe usage.
""" """
import functools import functools
from collections.abc import Callable
from typing import Any, Iterable, TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
import torch import torch
from torch._guards import Source
from torch.fx.proxy import Proxy
from .. import graph_break_hints from .. import graph_break_hints
from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported
@ -31,10 +33,19 @@ from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable from .user_defined import UserDefinedObjectVariable
def _raise_hard_error_if_graph_break(reason): if TYPE_CHECKING:
def deco(fn): from torch._dynamo.symbolic_convert import InstructionTranslator
_P = ParamSpec("_P")
_T = TypeVar("_T")
def _raise_hard_error_if_graph_break(
reason: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(fn) @functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs): def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T:
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
except Unsupported as e: except Unsupported as e:
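Typing the decorator with `ParamSpec` and a `TypeVar` lets the wrapped function keep its exact argument and return types for callers, instead of collapsing to `Callable[..., Any]`. A standalone sketch of the same shape (the error translation is simplified to built-in exceptions; `ParamSpec` is also importable from `typing` on Python 3.10+):

```python
import functools
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_T = TypeVar("_T")

def raise_hard_error(reason: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        @functools.wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            try:
                return fn(*args, **kwargs)
            except ValueError as e:  # stand-in for Unsupported
                raise RuntimeError(reason) from e
        return wrapper
    return deco

@raise_hard_error("tracing must not graph break")
def parse(x: str) -> int:
    return int(x)  # mypy still sees parse: (str) -> int

print(parse("42"))
```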
@ -49,26 +60,26 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
_fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {}
@classmethod @classmethod
def is_matching_cls(cls, user_cls: type): def is_matching_cls(cls, user_cls: type) -> bool:
return issubclass(user_cls, torch.ScriptObject) return issubclass(user_cls, torch.ScriptObject)
@staticmethod @staticmethod
def create(proxy, value, **options): def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
return TorchScriptObjectVariable(proxy, value, **options) return TorchScriptObjectVariable(proxy, value, **options)
def __init__(self, proxy, value, source, **kwargs) -> None: def __init__(self, proxy: Proxy, value: Any, source: Source, **kwargs: Any) -> None:
super().__init__(value, **kwargs) super().__init__(value, **kwargs)
self.proxy = proxy self.proxy = proxy
self.proxy.node.meta["example_value"] = value self.proxy.node.meta["example_value"] = value
self.source = source self.source = source
def as_proxy(self): def as_proxy(self) -> Proxy:
return self.proxy return self.proxy
@_raise_hard_error_if_graph_break( @_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break." "Dynamo cannot safely trace script object due to graph break."
) )
def var_getattr(self, tx, name: str) -> VariableTracker: def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource from ..source import AttrSource
@ -95,7 +106,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Use method calls instead of attribute access.", "Use method calls instead of attribute access.",
], ],
) )
assert self.source is not None
return TorchHigherOrderOperatorVariable.make( return TorchHigherOrderOperatorVariable.make(
call_torchbind, call_torchbind,
source=AttrSource(self.source, name), source=AttrSource(self.source, name),
@ -110,7 +121,13 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
@_raise_hard_error_if_graph_break( @_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break." "Dynamo cannot safely trace script object due to graph break."
) )
def call_method(self, tx, name, args, kwargs): def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> VariableTracker:
unimplemented_v2( unimplemented_v2(
gb_type="Weird method call on TorchScript object", gb_type="Weird method call on TorchScript object",
context=f"value={self.value}, method={name}", context=f"value={self.value}, method={name}",

View File

@ -1,7 +1,9 @@
# mypy: ignore-errors
from inspect import getattr_static from inspect import getattr_static
from typing import TYPE_CHECKING from typing import Any, Sequence, TYPE_CHECKING, TypeGuard
from torch._guards import Source
from torch.backends.cuda import SDPAParams
from torch.fx.proxy import Proxy
from ..bytecode_transformation import create_call_function from ..bytecode_transformation import create_call_function
from ..exc import Unsupported from ..exc import Unsupported
@ -29,9 +31,9 @@ class SDPAParamsVariable(VariableTracker):
This is a read-only container.""" This is a read-only container."""
@staticmethod @staticmethod
def create(tx: "InstructionTranslator", value, source): def create(
from torch.backends.cuda import SDPAParams tx: "InstructionTranslator", value: Any, source: Source
) -> VariableTracker:
from .torch import TorchInGraphFunctionVariable from .torch import TorchInGraphFunctionVariable
params = [ params = [
@ -40,12 +42,14 @@ class SDPAParamsVariable(VariableTracker):
] ]
return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
def __init__(self, proxy, param_vars, **kwargs) -> None: def __init__(
self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any
) -> None:
self.proxy = proxy self.proxy = proxy
self.param_vars = param_vars self.param_vars = param_vars
super().__init__(**kwargs) super().__init__(**kwargs)
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.source is None assert self.source is None
assert self.param_vars is not None assert self.param_vars is not None
codegen.add_push_null( codegen.add_push_null(
@ -54,7 +58,7 @@ class SDPAParamsVariable(VariableTracker):
codegen.foreach(self.param_vars) codegen.foreach(self.param_vars)
codegen.extend_output(create_call_function(len(self.param_vars), False)) codegen.extend_output(create_call_function(len(self.param_vars), False))
def as_proxy(self): def as_proxy(self) -> Proxy:
return self.proxy return self.proxy
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
@ -80,7 +84,5 @@ class SDPAParamsVariable(VariableTracker):
return wrap_fx_proxy(tx=tx, proxy=proxy) return wrap_fx_proxy(tx=tx, proxy=proxy)
@staticmethod @staticmethod
def is_sdpa_params(value): def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]:
from torch.backends.cuda import SDPAParams
return value is SDPAParams return value is SDPAParams
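The `TypeGuard` return annotation tells mypy that a True result narrows the checked argument in the caller. A generic example of the mechanism, unrelated to SDPAParams:

```python
from typing import Any
from typing_extensions import TypeGuard

def is_str_list(val: list[Any]) -> TypeGuard[list[str]]:
    return all(isinstance(x, str) for x in val)

def join(val: list[Any]) -> str:
    if is_str_list(val):
        # mypy treats `val` as list[str] inside this branch
        return ", ".join(val)
    return ""

print(join(["a", "b"]), join([1, 2]))
```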

View File

@ -232,6 +232,7 @@ class StreamVariable(StreamContextVariable):
return ConstantVariable.create(NotImplemented) return ConstantVariable.create(NotImplemented)
if other.source: if other.source:
assert self.source is not None
install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
return ConstantVariable.create( return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type] cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]

View File

@ -1464,6 +1464,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
): ):
# constant fold functions need to be guarded. # constant fold functions need to be guarded.
if self.value in constant_fold_functions_need_guards: if self.value in constant_fold_functions_need_guards:
assert self.source is not None
source = CallFunctionNoArgsSource(self.source) source = CallFunctionNoArgsSource(self.source)
install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH))
# constant fold # constant fold

View File

@ -1,5 +1,3 @@
# mypy: ignore-errors
"""TorchDynamo support for __torch_function__ tensor subclasses. """TorchDynamo support for __torch_function__ tensor subclasses.
This module implements support for tensor subclasses with __torch_function__ overrides. This module implements support for tensor subclasses with __torch_function__ overrides.
@ -31,7 +29,8 @@ import contextlib
import functools import functools
import inspect import inspect
import operator import operator
from typing import TYPE_CHECKING from types import TracebackType
from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING
import torch._C import torch._C
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -125,34 +124,134 @@ un_ops = [
banned_attrs = [ banned_attrs = [
fn.__self__.__name__ fn.__self__.__name__ # type: ignore[attr-defined]
for fn in get_default_nowrap_functions() for fn in get_default_nowrap_functions()
if is_tensor_base_attr_getter(fn) if is_tensor_base_attr_getter(fn)
] ]
@functools.cache @functools.cache
def get_prev_stack_var_name(): def get_prev_stack_var_name() -> str:
from ..bytecode_transformation import unique_id from ..bytecode_transformation import unique_id
return unique_id("___prev_torch_function_mode_stack") return unique_id("___prev_torch_function_mode_stack")
class TorchFunctionModeVariable(GenericContextWrappingVariable):
@staticmethod
def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool:
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__
)
def __init__(
self,
value: Optional[TorchFunctionMode],
source: Optional[Source] = None,
**kwargs: Any,
):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source # type: ignore[assignment]
def reconstruct(self, codegen: "PyCodegen") -> None:
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)
def module_name(self) -> str:
return self.value.__module__
def fn_name(self) -> str:
return type(self.value).__name__
def python_type(self) -> type:
return type(self.value)
def call_torch_function(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> VariableTracker:
return call_torch_function(
tx,
get_torch_function_fn(tx, self), # type: ignore[arg-type]
fn,
types,
args,
kwargs,
)
def enter(self, tx: "InstructionTranslator") -> VariableTracker:
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker:
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen: "PyCodegen") -> None:
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self) -> bool:
return True
def exit_on_graph_break(self) -> bool:
return False
# Used to clear/restore the python torch function mode stack and temporarily restore it as needed # Used to clear/restore the python torch function mode stack and temporarily restore it as needed
class TorchFunctionModeStackStateManager: class TorchFunctionModeStackStateManager:
def __init__(self): def __init__(self) -> None:
self.stack = [] self.stack: list[Any] = []
def __enter__(self): def __enter__(self) -> None:
self.stack = torch.overrides._get_current_function_mode_stack() self.stack = torch.overrides._get_current_function_mode_stack()
clear_torch_function_mode_stack() clear_torch_function_mode_stack()
def __exit__(self, exc_type, exc_value, traceback): def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
set_torch_function_mode_stack(self.stack) set_torch_function_mode_stack(self.stack)
self.stack = [] self.stack = []
@contextlib.contextmanager @contextlib.contextmanager
def temp_restore_stack(self): def temp_restore_stack(self) -> Generator[None, None, None]:
prev = torch.overrides._get_current_function_mode_stack() prev = torch.overrides._get_current_function_mode_stack()
set_torch_function_mode_stack(self.stack) set_torch_function_mode_stack(self.stack)
try: try:
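This hunk shows the two conventional ways to type context managers: explicit exception parameters on `__exit__`, and a `Generator[None, None, None]` (yield type, send type, return type) annotation on the `@contextlib.contextmanager` function that yields nothing. A compact sketch of both, assuming nothing beyond the standard library:

```python
import contextlib
from types import TracebackType
from typing import Generator, Optional

class SaveList:
    def __init__(self, values: list[int]) -> None:
        self.values = values
        self.saved: list[int] = []

    def __enter__(self) -> None:
        self.saved = list(self.values)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.values[:] = self.saved

@contextlib.contextmanager
def swapped(values: list[int], replacement: list[int]) -> Generator[None, None, None]:
    saved = list(values)
    values[:] = replacement
    try:
        yield  # yields nothing, hence the first None
    finally:
        values[:] = saved

data = [1, 2, 3]
with swapped(data, [9]):
    print(data)  # [9]
print(data)      # [1, 2, 3]
```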
@ -165,7 +264,7 @@ torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager()
class SymbolicTorchFunctionState: class SymbolicTorchFunctionState:
def __init__(self, py_stack): def __init__(self, py_stack: Iterable[Any]) -> None:
# This is annoyingly complicated because of how the torch function subclass + mode C API was designed # This is annoyingly complicated because of how the torch function subclass + mode C API was designed
# There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass
# These are their definitions: # These are their definitions:
@ -199,32 +298,41 @@ class SymbolicTorchFunctionState:
for i, val in enumerate(py_stack): for i, val in enumerate(py_stack):
self.mode_stack.append( self.mode_stack.append(
LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type]
) )
def in_torch_function_mode(self): def in_torch_function_mode(self) -> bool:
return len(self.mode_stack) > 0 return len(self.mode_stack) > 0
def pop_torch_function_mode(self): def pop_torch_function_mode(self) -> TorchFunctionModeVariable:
return self.mode_stack.pop() return self.mode_stack.pop()
def push_torch_function_mode(self, mode_var): def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None:
self.mode_stack.append(mode_var) self.mode_stack.append(mode_var)
def call_torch_function_mode(self, tx, fn, types, args, kwargs): def call_torch_function_mode(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
with self._pop_mode_for_inlining() as cur_mode: with self._pop_mode_for_inlining() as cur_mode:
return cur_mode.call_torch_function(tx, fn, types, args, kwargs) return cur_mode.call_torch_function(tx, fn, types, args, kwargs)
@contextlib.contextmanager @contextlib.contextmanager
def _pop_mode_for_inlining(self): def _pop_mode_for_inlining(
self,
) -> Generator[TorchFunctionModeVariable, None, None]:
old_mode = self.cur_mode old_mode = self.cur_mode
self.cur_mode = self.pop_torch_function_mode() self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment]
try: try:
yield self.cur_mode yield self.cur_mode # type: ignore[misc]
finally: finally:
mode = self.cur_mode mode = self.cur_mode
self.cur_mode = old_mode self.cur_mode = old_mode
self.push_torch_function_mode(mode) self.push_torch_function_mode(mode) # type: ignore[arg-type]
class TorchFunctionModeStackVariable(VariableTracker): class TorchFunctionModeStackVariable(VariableTracker):
@ -244,16 +352,20 @@ class TorchFunctionModeStackVariable(VariableTracker):
# each of the indices of other modes should be shifted left by 1 (-1) # each of the indices of other modes should be shifted left by 1 (-1)
offset = 0 offset = 0
def __init__(self, source, symbolic_stack): def __init__(
self,
source: Source,
symbolic_stack: collections.deque[TorchFunctionModeVariable],
) -> None:
self.source = source self.source = source
self.symbolic_stack = symbolic_stack self.symbolic_stack = symbolic_stack
@classmethod @classmethod
def reset(cls): def reset(cls) -> None:
cls.offset = 0 cls.offset = 0
@classmethod @classmethod
def register_mutation(cls, tx: "InstructionTranslator"): def register_mutation(cls, tx: "InstructionTranslator") -> None:
if cls.stack_value_singleton not in tx.output.side_effects: if cls.stack_value_singleton not in tx.output.side_effects:
var = cls( var = cls(
source=Source(), source=Source(),
@ -263,7 +375,7 @@ class TorchFunctionModeStackVariable(VariableTracker):
tx.output.side_effects.mutation(var) tx.output.side_effects.mutation(var)
@classmethod @classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"): def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None:
stack = tx.symbolic_torch_function_state.mode_stack stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]): if stack and cls.is_device_context(stack[0]):
return return
@ -277,109 +389,28 @@ class TorchFunctionModeStackVariable(VariableTracker):
) )
@classmethod @classmethod
def clear_default_device(cls, tx: "InstructionTranslator"): def clear_default_device(cls, tx: "InstructionTranslator") -> None:
stack = tx.symbolic_torch_function_state.mode_stack stack = tx.symbolic_torch_function_state.mode_stack
if stack and cls.is_device_context(stack[0]): if stack and cls.is_device_context(stack[0]):
stack.popleft() stack.popleft()
cls.offset -= 1 cls.offset -= 1
@staticmethod @staticmethod
def is_device_context(var): def is_device_context(var: TorchFunctionModeVariable) -> bool:
return isinstance(var.value, DeviceContext) or var.value is None return isinstance(var.value, DeviceContext) or var.value is None
@classmethod @classmethod
def get_mode_index(cls, ind): def get_mode_index(cls, ind: int) -> int:
return ind + cls.offset return ind + cls.offset
class TorchFunctionModeVariable(GenericContextWrappingVariable): def _get_all_args(
@staticmethod args: Iterable[Any], kwargs: dict[str, Any]
def is_supported_torch_function_mode(ty): ) -> Iterable[VariableTracker]:
# Supported in this sense means we can support graph breaks under the
# context.
# We are able to trace custom modes but if there are graph breaks under them
# and they have a custom __enter__/__exit__ we don't handle this for the
# same reason we don't handle generic context managers: there may be side effects
# that are now affected by executing the function across two frames instead of one
# Today we support the enter/exit of the default TorchFunctionMode as well as
# DeviceContext (which is used for set_default_device)
return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or (
not class_has_getattribute(ty)
and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__
and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__
)
def __init__(self, value, source=None, **kwargs):
if value is not None:
super().__init__(value, **kwargs)
self.value = value
self.cm_obj = value # needed for BC with calling enter from CM code
self.source = source
def reconstruct(self, codegen: "PyCodegen"):
# This shouldn't be called unless we have a source
assert self.source
self.source.reconstruct(codegen)
def module_name(self):
return self.value.__module__
def fn_name(self):
return type(self.value).__name__
def python_type(self):
return type(self.value)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
get_torch_function_fn(tx, self),
fn,
types,
args,
kwargs,
)
def enter(self, tx):
from .torch import TorchInGraphFunctionVariable
if isinstance(self.value, NoEnterTorchFunctionMode):
return ConstantVariable.create(None)
TorchInGraphFunctionVariable(
torch._C._push_on_torch_function_stack
).call_function(tx, [self], {})
return ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
from .torch import TorchInGraphFunctionVariable
TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function(
tx, [], {}
)
return ConstantVariable.create(None)
def reconstruct_type(self, codegen: "PyCodegen"):
ty = NoEnterTorchFunctionMode
codegen(
AttrSource(
codegen.tx.import_source(ty.__module__),
ty.__name__,
)
)
def supports_graph_breaks(self):
return True
def exit_on_graph_break(self):
return False
def _get_all_args(args, kwargs):
return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))
def _flatten_vts(vts): def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]:
from collections import deque from collections import deque
from .dicts import ConstDictVariable from .dicts import ConstDictVariable
@ -391,7 +422,7 @@ def _flatten_vts(vts):
while vts: while vts:
vt = vts.popleft() vt = vts.popleft()
if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined]
vt.realize() vt.realize()
if vt.is_realized(): if vt.is_realized():
@ -407,21 +438,28 @@ def _flatten_vts(vts):
return output return output
def _get_subclass_type(var): def _get_subclass_type(var: VariableTracker) -> type:
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
return var.python_type() return var.python_type()
def _get_subclass_type_var(tx: "InstructionTranslator", var): def _get_subclass_type_var(
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) tx: "InstructionTranslator", var: VariableTracker
) -> VariableTracker:
if isinstance(var, TensorWithTFOverrideVariable): if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx) return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable): elif isinstance(var, UserDefinedObjectVariable):
source = var.source and TypeSource(var.source) source = var.source and TypeSource(var.source)
return VariableTracker.build(tx, var.python_type(), source) return VariableTracker.build(tx, var.python_type(), source)
else:
raise AssertionError(f"Unexpected type {type(var)}")
def _is_attr_overridden(tx: "InstructionTranslator", var, name): def _is_attr_overridden(
tx: "InstructionTranslator", var: VariableTracker, name: str
) -> bool:
if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)):
return False
import torch import torch
overridden = False overridden = False
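The `else: raise AssertionError(...)` branch added to `_get_subclass_type_var` above makes the function total: every path either returns a value or raises, so mypy accepts the declared return type without an implicit `None` fall-through. A minimal version of the pattern:

```python
def describe(value: object) -> str:
    if isinstance(value, int):
        return "int"
    elif isinstance(value, str):
        return "str"
    else:
        # Without this raise, mypy flags the function with "missing return statement",
        # because the fall-through path would implicitly return None.
        raise AssertionError(f"Unexpected type {type(value)}")

print(describe(3), describe("x"))
```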
@ -434,7 +472,14 @@ def _is_attr_overridden(tx: "InstructionTranslator", var, name):
return overridden return overridden
def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): def call_torch_function(
tx: "InstructionTranslator",
torch_function_var: VariableTracker,
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
# This emulates calling __torch_function__, which has a signature # This emulates calling __torch_function__, which has a signature
# def __torch_function__(cls, func, types, args=(), kwargs=None): # def __torch_function__(cls, func, types, args=(), kwargs=None):
# #
@ -451,7 +496,9 @@ def call_torch_function(tx, torch_function_var, fn, types, args, kwargs):
return torch_function_var.call_function(tx, tf_args, {}) return torch_function_var.call_function(tx, tf_args, {})
def get_torch_function_fn(tx: "InstructionTranslator", vt): def get_torch_function_fn(
tx: "InstructionTranslator", vt: VariableTracker
) -> VariableTracker:
# The underlying function could be a classmethod, staticmethod, regular # The underlying function could be a classmethod, staticmethod, regular
# function or a function with C-implementation. It doesn't matter as long as # function or a function with C-implementation. It doesn't matter as long as
# they satisfy the calling convention in `call_torch_function`. # they satisfy the calling convention in `call_torch_function`.
@ -462,7 +509,9 @@ def get_torch_function_fn(tx: "InstructionTranslator", vt):
return func_vt return func_vt
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): def can_dispatch_torch_function(
tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any]
) -> bool:
has_overridden_args = any( has_overridden_args = any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs) has_torch_function(arg) for arg in _get_all_args(args, kwargs)
) )
@ -472,7 +521,12 @@ def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
) )
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): def dispatch_torch_function(
tx: "InstructionTranslator",
fn: VariableTracker,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
"""Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""
all_args = _get_all_args(args, kwargs) all_args = _get_all_args(args, kwargs)
@ -518,7 +572,13 @@ class TensorWithTFOverrideVariable(TensorVariable):
""" """
@classmethod @classmethod
def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): def from_tensor_var(
cls,
tx: "InstructionTranslator",
tensor_var: VariableTracker,
class_type: type,
cls_source: Source,
) -> "TensorWithTFOverrideVariable":
# [Note: __torch_function__] coerce `tensor_var` into a # [Note: __torch_function__] coerce `tensor_var` into a
# TensorWithTFOverrideVariable. In eager, this is just a type change. # TensorWithTFOverrideVariable. In eager, this is just a type change.
import torch import torch
@ -533,7 +593,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
var.install_global(tx) var.install_global(tx)
return var return var
def install_global(self, tx): def install_global(self, tx: "InstructionTranslator") -> None:
# stash the subclass type to rewrap an output tensor if needed # stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available # this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor. # each time the compiled artifact is run and outputs a wrapped tensor.
@ -543,20 +603,20 @@ class TensorWithTFOverrideVariable(TensorVariable):
self.global_mangled_class_name(tx), self.class_type self.global_mangled_class_name(tx), self.class_type
) )
def python_type(self): def python_type(self) -> type:
return self.class_type return self.class_type
def class_type_var(self, tx): def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker:
return TensorSubclassVariable( return TensorSubclassVariable(
self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
) )
def global_mangled_class_name(self, tx): def global_mangled_class_name(self, tx: "InstructionTranslator") -> str:
return get_safe_global_name( return get_safe_global_name(
tx, f"__subclass_{self.class_type.__name__}", self.class_type tx, f"__subclass_{self.class_type.__name__}", self.class_type
) )
def var_getattr(self, tx: "InstructionTranslator", name): def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
# [Note: __torch_function__] We currently only support attributes that are defined on # [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break. # base tensors, custom attribute accesses will graph break.
import torch import torch
@ -581,7 +641,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
and not attr_is_overridden and not attr_is_overridden
and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) and not inspect.ismethoddescriptor(getattr(torch.Tensor, name))
): ):
args, kwargs = [self], {} args = [self]
kwargs: dict[Any, Any] = {}
if can_dispatch_torch_function(tx, args, kwargs): if can_dispatch_torch_function(tx, args, kwargs):
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
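Splitting `args, kwargs = [self], {}` into two statements with an explicit `dict[Any, Any]` annotation gives mypy concrete key/value types for the empty literal, which it cannot pick up from a tuple assignment alone. A small illustration (generic types, not the Dynamo variables):

```python
from typing import Any

def build_call() -> tuple[list[int], dict[str, Any]]:
    # Annotating the empty dict makes its key/value types explicit; without it,
    # mypy may report "Need type annotation" for the bare {} literal.
    args = [1]
    kwargs: dict[str, Any] = {}
    kwargs["mode"] = "default"
    return args, kwargs

print(build_call())
```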
@ -636,7 +697,14 @@ class TensorWithTFOverrideVariable(TensorVariable):
return super().var_getattr(tx, name) return super().var_getattr(tx, name)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): def call_torch_function(
self,
tx: "InstructionTranslator",
fn: VariableTracker,
types: TupleVariable,
args: Iterable[Any],
kwargs: dict[str, Any],
) -> Any:
# NOTE this assumes `__torch_function__` isn't modified during tracing. # NOTE this assumes `__torch_function__` isn't modified during tracing.
if not hasattr(self, "torch_function_fn"): if not hasattr(self, "torch_function_fn"):
self.torch_function_fn = get_torch_function_fn(tx, self) self.torch_function_fn = get_torch_function_fn(tx, self)
@ -652,8 +720,8 @@ class TensorWithTFOverrideVariable(TensorVariable):
def call_method( def call_method(
self, self,
tx, tx: "InstructionTranslator",
name, name: str,
args: "list[VariableTracker]", args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]", kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker": ) -> "VariableTracker":