diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 05dd42866e8..b065c188bcb 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -18,7 +18,7 @@ import re import sys import types from collections import Counter -from typing import Optional, Union +from typing import Optional, TYPE_CHECKING, Union import torch.nn from torch.utils._ordered_set import OrderedSet @@ -54,6 +54,10 @@ from .variables.tensor import ( from .variables.torch_function import TensorWithTFOverrideVariable +if TYPE_CHECKING: + from .symbolic_convert import InstructionTranslatorBase + + @dataclasses.dataclass class GraphOutputEntry: index: int @@ -67,7 +71,7 @@ class PyCodegen: def __init__( self, - tx=None, + tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, tempvars=None, diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index c11e6deccc7..92a6ea2f15c 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -390,7 +390,7 @@ class OutputGraph: # and LOAD_ATTR for same python objects free. self.variable_tracker_cache = VariableTrackerCache() self.unique_var_id = itertools.count() - self.code_options = dict(code_options) + self.code_options: dict[str, Any] = dict(code_options) self.output_instructions: list[Instruction] = [] # used to track nodes that are added between calls of copy_graphstate # and restore_graphstate @@ -401,7 +401,7 @@ class OutputGraph: # Not checkpointed self.compiler_fn: Optional[CompilerFn] = compiler_fn - self.global_scope = global_scope + self.global_scope: Scope = global_scope self.local_scope = local_scope self.root_tx = root_tx @@ -462,7 +462,7 @@ class OutputGraph: self.random_calls: list[ tuple[Callable[..., object], tuple[object, ...], dict[str, object]] ] = [] - self.random_values_var = None + self.random_values_var: Any = None # Bytecode to insert right before we call the graph self.pregraph_bytecode: list[Instruction] = [] @@ -888,7 +888,9 @@ class OutputGraph: self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableTracker.build( - self, target, ConstantSource(source_name=module_key) + self, # type: ignore[arg-type] + target, + ConstantSource(source_name=module_key), ) for k, v in self.nn_modules.items(): diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 4116f110b21..f31d613170a 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -21,7 +21,7 @@ the code needed to recreate values. import dataclasses import enum -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union from torch._guards import ChainedSource, GuardSource, Source @@ -29,6 +29,9 @@ from . import utils from .bytecode_transformation import create_call_function, create_instruction +if TYPE_CHECKING: + from .codegen import PyCodegen + # It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, # so those cases are omitted intentionally @@ -120,7 +123,7 @@ class LocalSource(Source): # or `co_freevars`. is_derefed_cell_contents: bool = False - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.is_derefed_cell_contents: codegen.load_deref(self.local_name) else: @@ -137,7 +140,7 @@ class LocalSource(Source): class SyntheticLocalSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(self.local_name)) def guard_source(self): @@ -154,7 +157,7 @@ class RandomValueSource(Source): def guard_source(self): return GuardSource.RANDOM_VALUE - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) @@ -167,7 +170,7 @@ class RandomValueSource(Source): class GlobalSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.global_name, add=True)) def guard_source(self): @@ -181,7 +184,7 @@ class GlobalSource(Source): class GlobalWeakRefSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_global(self.global_name, add=True) @@ -198,7 +201,7 @@ class GlobalWeakRefSource(Source): @dataclasses.dataclass(frozen=True) class WeakRefCallSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(self.base)) codegen.extend_output(create_call_function(0, False)) @@ -227,7 +230,7 @@ class AttrSource(ChainedSource): ) object.__setattr__(self, "member", member_parts[-1]) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -249,7 +252,7 @@ class LocalCellSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, # Dynamo's bytecode transformation differentiates them slightly, so we # always emit `LOAD_CLOSURE` here. @@ -267,7 +270,7 @@ class LocalCellSource(Source): class GradSource(ChainedSource): member: str = "grad" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -342,7 +345,7 @@ class TensorPropertySource(ChainedSource): else: assert self.idx is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, f"call_{self.prop.method_name()}" @@ -378,7 +381,7 @@ class IndexedSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -393,7 +396,7 @@ class NegateSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -450,7 +453,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource): class AttrProxySource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -484,7 +487,7 @@ class DefaultsSource(ChainedSource): self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.field)) codegen.append_output(codegen.create_load_const(self.idx_key)) @@ -509,7 +512,7 @@ class GetItemSource(ChainedSource): super().__setattr__("index", self.index.__reduce__()) super().__setattr__("index_is_slice", True) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) if self.index_is_slice: codegen.append_output(codegen.create_load_const(self.unpack_slice())) @@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") ) @@ -577,7 +580,7 @@ class DictGetItemSource(ChainedSource): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # reconstruct dict.__getitem__(dct, key) # Load dict.__getitem__ @@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource): Same as GetItemSource with reconstruct and name overridden to be list specific. """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Reconstruct list.__getitem__(lst, index) to avoid any side effects # from possibly overridden __getitem__. @@ -646,7 +649,7 @@ class ListGetItemSource(GetItemSource): @dataclasses.dataclass(frozen=True) class TupleIteratorGetItemSource(GetItemSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") ) @@ -663,7 +666,7 @@ class TypeSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) @@ -677,7 +680,7 @@ class TypeSource(ChainedSource): @dataclasses.dataclass(frozen=True) class OptimizerSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -689,7 +692,7 @@ class OptimizerSource(ChainedSource): @dataclasses.dataclass(frozen=True) class NNModuleSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -738,7 +741,7 @@ class TorchFunctionModeStackSource(Source): return TorchFunctionModeStackVariable.get_mode_index(self.ind) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, "get_torch_function_mode_stack_at" @@ -755,7 +758,7 @@ class TorchFunctionModeStackSource(Source): class ConstantSource(Source): source_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.source_name, add=False)) def guard_source(self): @@ -776,7 +779,7 @@ class NumpyTensorSource(ChainedSource): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index fbf780bf7fa..e5274d0f0ce 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -29,7 +29,8 @@ from ..utils import cmp_name_to_op_mapping, istype if TYPE_CHECKING: - from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase class SourceType(Enum): @@ -399,7 +400,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: return None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def unpack_var_sequence(self, tx) -> list["VariableTracker"]: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d5cea823b7f..d85885449b0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -276,6 +276,7 @@ except ModuleNotFoundError: if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -348,7 +349,7 @@ class GraphArg: self._example = TensorWeakRef(self._example) assert is_fake(self.fake_tensor) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.source) def erase(self): @@ -369,7 +370,7 @@ class BackwardStateGraphArg(GraphArg): is_tensor=False, ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert codegen.tx.output.backward_state_var codegen.add_push_null( lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5360868dd7e..2a7d031b7b8 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -87,6 +87,7 @@ from .user_defined import UserDefinedObjectVariable, UserDefinedVariable if TYPE_CHECKING: # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator log = logging.getLogger(__name__) @@ -730,7 +731,7 @@ class BuiltinVariable(VariableTracker): return DTYPE[self.fn] return super().as_proxy() - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"): + def reconstruct(self, codegen: "PyCodegen"): name = self.fn.__name__ assert self.fn.__module__ == "builtins" assert name not in codegen.tx.f_globals, "shadowed global" diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 6760bd1ff73..f86d2d2062a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -133,7 +133,7 @@ its type to `common_constant_types`. def call_method( self, - tx, + tx: "InstructionTranslator", name, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 04f552c54fa..7cbed617d82 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -50,6 +50,7 @@ from .user_defined import UserDefinedObjectVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -85,12 +86,12 @@ class ContextWrappingVariable(VariableTracker): self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -1057,7 +1058,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable): _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1278,7 +1279,7 @@ class StreamVariable(VariableTracker): def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this stream is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1340,7 +1341,7 @@ class EventVariable(VariableTracker): def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this event is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1378,7 +1379,7 @@ class WithExitFunctionVariable(VariableTracker): assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 60ae7744461..7c38539bd21 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -44,6 +44,7 @@ from .constant import ConstantVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -263,7 +264,7 @@ class ConstDictVariable(VariableTracker): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.add_push_null( @@ -546,7 +547,7 @@ class MappingProxyVariable(VariableTracker): def unpack_var_sequence(self, tx): return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # load types.MappingProxyType if self.source: unimplemented( @@ -681,7 +682,7 @@ class SetVariable(ConstDictVariable): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) @@ -786,7 +787,7 @@ class FrozensetVariable(SetVariable): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -879,7 +880,7 @@ class DictViewVariable(VariableTracker): def unpack_var_sequence(self, tx): return self.view_items_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 257ccac4d37..d8beec6aaeb 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -75,6 +75,7 @@ except ModuleNotFoundError: if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, @@ -470,7 +471,7 @@ class LocalGeneratorObjectVariable(VariableTracker): __repr__ = __str__ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -1109,7 +1110,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): return result - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) @@ -1506,7 +1507,7 @@ class FunctoolsPartialVariable(VariableTracker): def python_type(self): return functools.partial - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -1962,7 +1963,7 @@ class TMADescriptorVariable(VariableTracker): self.element_size.as_proxy(), ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 502616c440e..3cf9c994ddc 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -34,6 +34,7 @@ from .constant import ConstantVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -249,7 +250,7 @@ class RepeatIteratorVariable(IteratorVariable): def next_variable(self, tx): return self.item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -279,7 +280,7 @@ class CountIteratorVariable(IteratorVariable): self.item = self.item.call_method(tx, "__add__", [self.step], {}) return old_item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -425,7 +426,7 @@ class ZipVariable(IteratorVariable): self.index += 1 return variables.TupleVariable(args) - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): for it in self.iterables: if isinstance(it, list): remaining_items = it[self.index :] @@ -436,7 +437,7 @@ class ZipVariable(IteratorVariable): else: codegen(it) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True ) @@ -481,7 +482,7 @@ class MapVariable(ZipVariable): args = super().next_variable(tx) return self.fn.call_function(tx, args.items, {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True ) @@ -555,7 +556,7 @@ class FilterVariable(IteratorVariable): if pred_res.as_python_constant(): return item - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): if isinstance(self.iterable, list): remaining_items = self.iterable[self.index :] codegen.foreach(remaining_items) @@ -565,7 +566,7 @@ class FilterVariable(IteratorVariable): else: codegen(self.iterable) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) codegen(self.fn) self.reconstruct_items(codegen) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2c92599a8b2..1430dc912cb 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -57,6 +57,7 @@ from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObject if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -81,7 +82,7 @@ class SuperVariable(VariableTracker): # cls for a classmethod) self.objvar = objvar - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) codegen(self.typevar) if self.objvar is not None: @@ -331,7 +332,7 @@ class ExceptionVariable(VariableTracker): def set_context(self, context: "ExceptionVariable"): self.__context__ = context - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", self.exc_type.__name__) ) @@ -460,7 +461,7 @@ class ComptimeVariable(VariableTracker): Dynamo compile time """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError("comptime is special form") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": @@ -944,7 +945,7 @@ class GetAttrVariable(VariableTracker): raise NotImplementedError return inspect.getattr_static(step2, name) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.obj) codegen.extend_output(codegen.create_load_attrs(self.name)) @@ -1161,7 +1162,7 @@ class TypingVariable(VariableTracker): def as_python_constant(self): return self.value - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen") -> None: # We're just trying to load the type here. Reconstructing the type from # scratch is tricky - for a type like `typing.List[int]` we'd need to # deconstruct the origin and args. The origin for `List[int]` is `list` @@ -1336,7 +1337,7 @@ class NullVariable(VariableTracker): def __repr__(self) -> str: return "NullVariable" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if sys.version_info < (3, 11): unimplemented("cannot reconstruct NullVariable in < Python 3.11") codegen.append_output(create_instruction("PUSH_NULL")) @@ -1377,7 +1378,7 @@ class StringFormatVariable(VariableTracker): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1426,7 +1427,7 @@ class DebuggingVariable(VariableTracker): tx.debug_locals.append((self, list(args))) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): return self.source.reconstruct(codegen) @staticmethod @@ -1721,7 +1722,7 @@ class RandomVariable(VariableTracker): return call_random_fn(tx, call_random_meth, args, kwargs) return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1762,7 +1763,7 @@ class WeakRefVariable(VariableTracker): ) -> "VariableTracker": return self.referent_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) codegen(self.referent_vt) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 51c1ea6bf14..6edd4a7c8ea 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -10,6 +10,7 @@ from .base import VariableTracker if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split() @@ -36,7 +37,7 @@ class SDPAParamsVariable(VariableTracker): self.param_vars = param_vars super().__init__(**kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert self.source is None assert self.param_vars is not None codegen.add_push_null( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index c477979fa9e..ef6a69ceee7 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -80,6 +80,7 @@ except ModuleNotFoundError: if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -1558,7 +1559,7 @@ class UntypedStorageVariable(VariableTracker): return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("untyped_storage") codegen.call_method(0) @@ -1573,7 +1574,7 @@ class DataPtrVariable(VariableTracker): super().__init__(**kwargs) self.from_tensor = from_tensor - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("data_ptr") codegen.call_method(0) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 8034f440e77..429b3b57277 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -223,7 +223,7 @@ class BaseTorchVariable(VariableTracker): super().__init__(**kwargs) self.value = value - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): try: name = f"{self.value.__module__}.{self.value.__name__}" except Exception: diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 330faf9bf90..982a6511771 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -67,6 +67,7 @@ from .user_defined import UserDefinedObjectVariable if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -382,7 +383,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable): self.cm_obj = value # needed for BC with calling enter from CM code self.source = source - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) @@ -426,7 +427,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable): ) return ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): ty = NoEnterTorchFunctionMode codegen( AttrSource( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2d22e0d3580..fc39d238f30 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -97,6 +97,7 @@ except ImportError: if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -1507,7 +1508,7 @@ class RemovableHandleVariable(VariableTracker): return variables.ConstantVariable.create(None) super().call_method(tx, method_name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.idx == self.REMOVED: # Hook has already been removed, return a dummy handle codegen.add_push_null(