[Dynamo][Misc] Apply typing hints for codegen (#150289)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150289
Approved by: https://github.com/Skylion007, https://github.com/cyyever
Author: Yuanhao Ji, 2025-04-04 14:26:22 +00:00 (committed by PyTorch MergeBot)
parent 295b7e21eb
commit 1b0a023dde
17 changed files with 101 additions and 80 deletions

torch/_dynamo/codegen.py

@@ -18,7 +18,7 @@ import re
 import sys
 import types
 from collections import Counter
-from typing import Optional, Union
+from typing import Optional, TYPE_CHECKING, Union

 import torch.nn
 from torch.utils._ordered_set import OrderedSet
@@ -54,6 +54,10 @@ from .variables.tensor import (
 from .variables.torch_function import TensorWithTFOverrideVariable

+if TYPE_CHECKING:
+    from .symbolic_convert import InstructionTranslatorBase
+
+
 @dataclasses.dataclass
 class GraphOutputEntry:
     index: int
@@ -67,7 +71,7 @@ class PyCodegen:
     def __init__(
         self,
-        tx=None,
+        tx: "InstructionTranslatorBase",
         root: Optional[torch.nn.Module] = None,
         graph_output_var: Optional[str] = None,
         tempvars=None,
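
Every file in this commit applies the recipe shown above: the annotation-only dependency is imported under `typing.TYPE_CHECKING`, which is false at runtime, and then referenced as a quoted forward reference. The type checker sees the import; the interpreter never executes it, so the circular dependency between `codegen.py` and `symbolic_convert.py` is avoided. A minimal self-contained sketch of the pattern (module and class names here are invented for illustration, not the real Dynamo ones):

    # codegen_like.py -- stands in for torch/_dynamo/codegen.py
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, never at runtime,
        # so importing the "other side" of the cycle is safe here.
        from translator_like import TranslatorBase


    class CodegenLike:
        def __init__(self, tx: "TranslatorBase", graph_output_var: Optional[str] = None):
            # The quoted annotation is a forward reference: stored as a string
            # and resolved only by the checker, not by the interpreter.
            self.tx = tx
            self.graph_output_var = graph_output_var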

torch/_dynamo/output_graph.py

@@ -390,7 +390,7 @@ class OutputGraph:
         # and LOAD_ATTR for same python objects free.
         self.variable_tracker_cache = VariableTrackerCache()
         self.unique_var_id = itertools.count()
-        self.code_options = dict(code_options)
+        self.code_options: dict[str, Any] = dict(code_options)
         self.output_instructions: list[Instruction] = []
         # used to track nodes that are added between calls of copy_graphstate
         # and restore_graphstate
@@ -401,7 +401,7 @@
         # Not checkpointed
         self.compiler_fn: Optional[CompilerFn] = compiler_fn
-        self.global_scope = global_scope
+        self.global_scope: Scope = global_scope
         self.local_scope = local_scope
         self.root_tx = root_tx
@@ -462,7 +462,7 @@
         self.random_calls: list[
             tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
         ] = []
-        self.random_values_var = None
+        self.random_values_var: Any = None

         # Bytecode to insert right before we call the graph
         self.pregraph_bytecode: list[Instruction] = []
@@ -888,7 +888,9 @@
             self.output.update_co_names(module_key)
             self.global_scope[module_key] = target
             return VariableTracker.build(
-                self, target, ConstantSource(source_name=module_key)
+                self,  # type: ignore[arg-type]
+                target,
+                ConstantSource(source_name=module_key),
             )

         for k, v in self.nn_modules.items():
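
The `output_graph.py` hunks illustrate the two other tools this commit leans on: explicit attribute annotations at the point of first assignment (`self.code_options: dict[str, Any]`, `self.random_values_var: Any`), and a narrowly scoped `# type: ignore[arg-type]` where a call is intentional but the checker cannot prove it. A small sketch, with invented names, of why the `Any` annotation on `random_values_var` matters:

    from typing import Any


    class GraphLike:
        def __init__(self) -> None:
            # Without an annotation the checker infers type `None` from this
            # initializer and rejects every later assignment; `Any` keeps the
            # attribute deliberately loose.
            self.random_values_var: Any = None

        def set_var(self, name: str) -> None:
            self.random_values_var = name  # accepted only because of the annotation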

torch/_dynamo/source.py

@@ -21,7 +21,7 @@ the code needed to recreate values.
 import dataclasses
 import enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 from torch._guards import ChainedSource, GuardSource, Source
@@ -29,6 +29,9 @@ from . import utils
 from .bytecode_transformation import create_call_function, create_instruction

+if TYPE_CHECKING:
+    from .codegen import PyCodegen
+

 # It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
 # so those cases are omitted intentionally
@@ -120,7 +123,7 @@ class LocalSource(Source):
     # or `co_freevars`.
     is_derefed_cell_contents: bool = False

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         if self.is_derefed_cell_contents:
             codegen.load_deref(self.local_name)
         else:
@@ -137,7 +140,7 @@ class LocalSource(Source):
 class SyntheticLocalSource(Source):
     local_name: str

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.append_output(codegen.create_load(self.local_name))

     def guard_source(self):
@@ -154,7 +157,7 @@ class RandomValueSource(Source):
     def guard_source(self):
         return GuardSource.RANDOM_VALUE

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
         codegen.append_output(codegen.create_load_const(self.random_call_index))
         codegen.append_output(create_instruction("BINARY_SUBSCR"))
@@ -167,7 +170,7 @@ class RandomValueSource(Source):
 class GlobalSource(Source):
     global_name: str

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.append_output(codegen.create_load_global(self.global_name, add=True))

     def guard_source(self):
@@ -181,7 +184,7 @@ class GlobalSource(Source):
 class GlobalWeakRefSource(Source):
     global_name: str

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.append_output(
                 codegen.create_load_global(self.global_name, add=True)
@@ -198,7 +201,7 @@ class GlobalWeakRefSource(Source):
 @dataclasses.dataclass(frozen=True)
 class WeakRefCallSource(ChainedSource):
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen(self.base))
         codegen.extend_output(create_call_function(0, False))
@@ -227,7 +230,7 @@ class AttrSource(ChainedSource):
             )
             object.__setattr__(self, "member", member_parts[-1])

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)
         codegen.extend_output(codegen.create_load_attrs(self.member))
@@ -249,7 +252,7 @@ class LocalCellSource(Source):
     local_name: str

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
         # Dynamo's bytecode transformation differentiates them slightly, so we
         # always emit `LOAD_CLOSURE` here.
@@ -267,7 +270,7 @@ class LocalCellSource(Source):
 class GradSource(ChainedSource):
     member: str = "grad"

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)
         codegen.extend_output(codegen.create_load_attrs(self.member))
@@ -342,7 +345,7 @@ class TensorPropertySource(ChainedSource):
         else:
             assert self.idx is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(
                 utils.__name__, f"call_{self.prop.method_name()}"
@@ -378,7 +381,7 @@ class IndexedSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         raise NotImplementedError

     def guard_source(self):
@@ -393,7 +396,7 @@ class NegateSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         raise NotImplementedError

     def guard_source(self):
@@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -450,7 +453,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
 class AttrProxySource(ChainedSource):
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -484,7 +487,7 @@ class DefaultsSource(ChainedSource):
             self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
         )

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)
         codegen.extend_output(codegen.create_load_attrs(self.field))
         codegen.append_output(codegen.create_load_const(self.idx_key))
@@ -509,7 +512,7 @@ class GetItemSource(ChainedSource):
             super().__setattr__("index", self.index.__reduce__())
             super().__setattr__("index_is_slice", True)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)
         if self.index_is_slice:
             codegen.append_output(codegen.create_load_const(self.unpack_slice()))
@@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource):
     def guard_source(self):
         return self.base.guard_source()

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
         )
@@ -577,7 +580,7 @@ class DictGetItemSource(ChainedSource):
     def guard_source(self):
         return self.base.guard_source()

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # reconstruct dict.__getitem__(dct, key)

         # Load dict.__getitem__
@@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource):
     Same as GetItemSource with reconstruct and name overridden to be list specific.
     """

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # Reconstruct list.__getitem__(lst, index) to avoid any side effects
         # from possibly overridden __getitem__.
@@ -646,7 +649,7 @@ class ListGetItemSource(GetItemSource):
 @dataclasses.dataclass(frozen=True)
 class TupleIteratorGetItemSource(GetItemSource):
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
         )
@@ -663,7 +666,7 @@ class TypeSource(ChainedSource):
     def __post_init__(self):
         assert self.base is not None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
         codegen(self.base)
         codegen.extend_output(create_call_function(1, False))
@@ -677,7 +680,7 @@ class TypeSource(ChainedSource):
 @dataclasses.dataclass(frozen=True)
 class OptimizerSource(ChainedSource):
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -689,7 +692,7 @@ class OptimizerSource(ChainedSource):
 @dataclasses.dataclass(frozen=True)
 class NNModuleSource(ChainedSource):
-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.base)

     def guard_source(self):
@@ -738,7 +741,7 @@ class TorchFunctionModeStackSource(Source):
         return TorchFunctionModeStackVariable.get_mode_index(self.ind)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(
                 utils.__name__, "get_torch_function_mode_stack_at"
@@ -755,7 +758,7 @@ class TorchFunctionModeStackSource(Source):
 class ConstantSource(Source):
     source_name: str

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.append_output(codegen.create_load_global(self.source_name, add=False))

     def guard_source(self):
@@ -776,7 +779,7 @@ class NumpyTensorSource(ChainedSource):
     def guard_source(self):
         return self.base.guard_source()

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
         codegen(self.base)
         codegen.extend_output(create_call_function(1, False))
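
The `source.py` changes are deliberately mechanical: every `reconstruct` in the `Source` hierarchy gains the identical `codegen: "PyCodegen"` annotation. The quotes are required because `PyCodegen` exists only under `TYPE_CHECKING`; an unquoted annotation would be evaluated at class-definition time and raise `NameError`. Annotating the whole hierarchy uniformly also lets the checker verify that each override matches the base signature. A reduced sketch with stand-in names:

    import dataclasses
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from codegen_like import CodegenLike  # hypothetical, as in the earlier sketch


    @dataclasses.dataclass(frozen=True)
    class SourceLike:
        def reconstruct(self, codegen: "CodegenLike") -> None:
            raise NotImplementedError


    @dataclasses.dataclass(frozen=True)
    class GlobalSourceLike(SourceLike):
        global_name: str

        # Same parameter type as the base method, so the checker can verify
        # every override in the hierarchy for free.
        def reconstruct(self, codegen: "CodegenLike") -> None:
            codegen.append_output(codegen.create_load_global(self.global_name, add=True))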

torch/_dynamo/variables/base.py

@@ -29,7 +29,8 @@ from ..utils import cmp_name_to_op_mapping, istype
 if TYPE_CHECKING:
-    from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase
+    from ..codegen import PyCodegen
+    from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase


 class SourceType(Enum):
@@ -399,7 +400,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         except NotImplementedError:
             return None

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         raise NotImplementedError

     def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
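
Hand-quoting each annotation, as this commit does, is one of two standard spellings. The alternative (not used here) is PEP 563's `from __future__ import annotations`, which stringifies every annotation in the module so the quotes can be dropped; a hypothetical version of the `VariableTracker` stub under that style:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from codegen_like import CodegenLike  # hypothetical


    class VariableTrackerLike:
        # An unquoted reference is safe: PEP 563 defers evaluation of all
        # annotations, so CodegenLike need not exist at runtime.
        def reconstruct(self, codegen: CodegenLike) -> None:
            raise NotImplementedError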

torch/_dynamo/variables/builder.py

@@ -276,6 +276,7 @@ except ModuleNotFoundError:

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -348,7 +349,7 @@ class GraphArg:
             self._example = TensorWeakRef(self._example)
         assert is_fake(self.fake_tensor)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.source)

     def erase(self):
@@ -369,7 +370,7 @@ class BackwardStateGraphArg(GraphArg):
             is_tensor=False,
         )

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         assert codegen.tx.output.backward_state_var
         codegen.add_push_null(
             lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState")

torch/_dynamo/variables/builtin.py

@@ -87,6 +87,7 @@ from .user_defined import UserDefinedObjectVariable, UserDefinedVariable

 if TYPE_CHECKING:
     # Cyclic dependency...
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator


 log = logging.getLogger(__name__)
@@ -730,7 +731,7 @@ class BuiltinVariable(VariableTracker):
             return DTYPE[self.fn]
         return super().as_proxy()

-    def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"):
+    def reconstruct(self, codegen: "PyCodegen"):
         name = self.fn.__name__
         assert self.fn.__module__ == "builtins"
         assert name not in codegen.tx.f_globals, "shadowed global"

torch/_dynamo/variables/constant.py

@@ -133,7 +133,7 @@ its type to `common_constant_types`.
     def call_method(
         self,
-        tx,
+        tx: "InstructionTranslator",
         name,
         args: "list[VariableTracker]",
         kwargs: "dict[str, VariableTracker]",

torch/_dynamo/variables/ctx_manager.py

@@ -50,6 +50,7 @@ from .user_defined import UserDefinedObjectVariable

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -85,12 +86,12 @@ class ContextWrappingVariable(VariableTracker):
         self.cleanup_assert()
         return variables.ConstantVariable.create(None)

-    def reconstruct_type(self, codegen):
+    def reconstruct_type(self, codegen: "PyCodegen"):
         codegen(
             AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
         )

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: self.reconstruct_type(codegen))
         target_values = self.target_values
         if not target_values:
@@ -1057,7 +1058,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
             _unsafe_set_version_counter
         ).call_function(tx, [self.tensors, self.prev_versions], {})

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         unimplemented_v2(
             gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
             context=str(self),
@@ -1278,7 +1279,7 @@ class StreamVariable(VariableTracker):
     def as_proxy(self):
         return self.proxy

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # If we got here, this stream is fully subsumed by the graph - this means it is
         # not an input or global
         assert not self.source
@@ -1340,7 +1341,7 @@ class EventVariable(VariableTracker):
     def as_proxy(self):
         return self.proxy

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # If we got here, this event is fully subsumed by the graph - this means it is
         # not an input or global
         assert not self.source
@@ -1378,7 +1379,7 @@ class WithExitFunctionVariable(VariableTracker):
         assert not kwargs
         return self.ctx.exit(tx, *args)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # Note here we reconstruct the context manager rather than the
         # exit function. The handler generated by BlockStackEntry
         # will re-enter the context in the resume function.

torch/_dynamo/variables/dicts.py

@@ -44,6 +44,7 @@ from .constant import ConstantVariable

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -263,7 +264,7 @@ class ConstDictVariable(VariableTracker):
                 return id(value.realize()) != id(other.realize())
             return id(value) != id(other)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # instructions to load collections.OrderedDict if necessary
         if self.user_cls is collections.OrderedDict:
             codegen.add_push_null(
@@ -546,7 +547,7 @@ class MappingProxyVariable(VariableTracker):
     def unpack_var_sequence(self, tx):
         return self.dv_dict.unpack_var_sequence(tx)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # load types.MappingProxyType
         if self.source:
             unimplemented(
@@ -681,7 +682,7 @@ class SetVariable(ConstDictVariable):
     def as_python_constant(self):
         return {k.vt.as_python_constant() for k in self.set_items}

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.foreach([x.vt for x in self.set_items])
         codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
@@ -786,7 +787,7 @@ class FrozensetVariable(SetVariable):
     def as_python_constant(self):
         return {k.vt.as_python_constant() for k in self.set_items}

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.foreach([x.vt for x in self.set_items])
         codegen.add_push_null(
             lambda: codegen.extend_output(
@@ -879,7 +880,7 @@ class DictViewVariable(VariableTracker):
     def unpack_var_sequence(self, tx):
         return self.view_items_vt

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.dv_dict)
         codegen.load_method(self.kv)
         codegen.call_method(0)

torch/_dynamo/variables/functions.py

@@ -75,6 +75,7 @@ except ModuleNotFoundError:

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator
     from torch._higher_order_ops.triton_kernel_wrap import (
         TritonGridType,
@@ -470,7 +471,7 @@ class LocalGeneratorObjectVariable(VariableTracker):

     __repr__ = __str__

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         from torch._dynamo.side_effects import disallow_side_effects_in_generator
         from torch._dynamo.symbolic_convert import (
             InstructionTranslator,
@@ -1109,7 +1110,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):

         return result

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(__name__, "_create_nested_fn")
         )
@@ -1506,7 +1507,7 @@ class FunctoolsPartialVariable(VariableTracker):
     def python_type(self):
         return functools.partial

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
         codegen(self.func)
         if self.args:
@@ -1962,7 +1963,7 @@ class TMADescriptorVariable(VariableTracker):
             self.element_size.as_proxy(),
         )

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from(
                 "triton.tools.experimental_descriptor",

torch/_dynamo/variables/iter.py

@@ -34,6 +34,7 @@ from .constant import ConstantVariable

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -249,7 +250,7 @@ class RepeatIteratorVariable(IteratorVariable):
     def next_variable(self, tx):
         return self.item

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -279,7 +280,7 @@ class CountIteratorVariable(IteratorVariable):
         self.item = self.item.call_method(tx, "__add__", [self.step], {})
         return old_item

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -425,7 +426,7 @@ class ZipVariable(IteratorVariable):
         self.index += 1
         return variables.TupleVariable(args)

-    def reconstruct_items(self, codegen):
+    def reconstruct_items(self, codegen: "PyCodegen"):
         for it in self.iterables:
             if isinstance(it, list):
                 remaining_items = it[self.index :]
@@ -436,7 +437,7 @@ class ZipVariable(IteratorVariable):
             else:
                 codegen(it)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
         )
@@ -481,7 +482,7 @@ class MapVariable(ZipVariable):
         args = super().next_variable(tx)
         return self.fn.call_function(tx, args.items, {})

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
         )
@@ -555,7 +556,7 @@ class FilterVariable(IteratorVariable):
             if pred_res.as_python_constant():
                 return item

-    def reconstruct_items(self, codegen):
+    def reconstruct_items(self, codegen: "PyCodegen"):
         if isinstance(self.iterable, list):
             remaining_items = self.iterable[self.index :]
             codegen.foreach(remaining_items)
@@ -565,7 +566,7 @@ class FilterVariable(IteratorVariable):
         else:
             codegen(self.iterable)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
         codegen(self.fn)
         self.reconstruct_items(codegen)

torch/_dynamo/variables/misc.py

@@ -57,6 +57,7 @@ from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObject

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -81,7 +82,7 @@ class SuperVariable(VariableTracker):
         # cls for a classmethod)
         self.objvar = objvar

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
         codegen(self.typevar)
         if self.objvar is not None:
@@ -331,7 +332,7 @@ class ExceptionVariable(VariableTracker):
     def set_context(self, context: "ExceptionVariable"):
         self.__context__ = context

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
         )
@@ -460,7 +461,7 @@ class ComptimeVariable(VariableTracker):
     Dynamo compile time
     """

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         raise NotImplementedError("comptime is special form")

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
@@ -944,7 +945,7 @@ class GetAttrVariable(VariableTracker):
             raise NotImplementedError
         return inspect.getattr_static(step2, name)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.obj)
         codegen.extend_output(codegen.create_load_attrs(self.name))
@@ -1161,7 +1162,7 @@ class TypingVariable(VariableTracker):
     def as_python_constant(self):
         return self.value

-    def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None:
+    def reconstruct(self, codegen: "PyCodegen") -> None:
         # We're just trying to load the type here. Reconstructing the type from
         # scratch is tricky - for a type like `typing.List[int]` we'd need to
         # deconstruct the origin and args. The origin for `List[int]` is `list`
@@ -1336,7 +1337,7 @@ class NullVariable(VariableTracker):
     def __repr__(self) -> str:
         return "NullVariable"

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         if sys.version_info < (3, 11):
             unimplemented("cannot reconstruct NullVariable in < Python 3.11")
         codegen.append_output(create_instruction("PUSH_NULL"))
@@ -1377,7 +1378,7 @@ class StringFormatVariable(VariableTracker):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -1426,7 +1427,7 @@ class DebuggingVariable(VariableTracker):

         tx.debug_locals.append((self, list(args)))

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         return self.source.reconstruct(codegen)

     @staticmethod
@@ -1721,7 +1722,7 @@ class RandomVariable(VariableTracker):
             return call_random_fn(tx, call_random_meth, args, kwargs)
         return super().call_method(tx, name, args, kwargs)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(
             lambda: codegen.extend_output(
                 [
@@ -1762,7 +1763,7 @@ class WeakRefVariable(VariableTracker):
     ) -> "VariableTracker":
         return self.referent_vt

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
         codegen(self.referent_vt)
         codegen.extend_output(create_call_function(1, False))

torch/_dynamo/variables/sdpa.py

@@ -10,6 +10,7 @@ from .base import VariableTracker

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator


 PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
@@ -36,7 +37,7 @@ class SDPAParamsVariable(VariableTracker):
         self.param_vars = param_vars
         super().__init__(**kwargs)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         assert self.source is None
         assert self.param_vars is not None
         codegen.add_push_null(

torch/_dynamo/variables/tensor.py

@@ -80,6 +80,7 @@ except ModuleNotFoundError:

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -1558,7 +1559,7 @@ class UntypedStorageVariable(VariableTracker):

         return super().call_method(tx, name, args, kwargs)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.from_tensor)
         codegen.load_method("untyped_storage")
         codegen.call_method(0)
@@ -1573,7 +1574,7 @@ class DataPtrVariable(VariableTracker):
         super().__init__(**kwargs)
         self.from_tensor = from_tensor

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         codegen(self.from_tensor)
         codegen.load_method("data_ptr")
         codegen.call_method(0)

torch/_dynamo/variables/torch.py

@@ -223,7 +223,7 @@ class BaseTorchVariable(VariableTracker):
         super().__init__(**kwargs)
         self.value = value

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         try:
             name = f"{self.value.__module__}.{self.value.__name__}"
         except Exception:

torch/_dynamo/variables/torch_function.py

@@ -67,6 +67,7 @@ from .user_defined import UserDefinedObjectVariable

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -382,7 +383,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
         self.cm_obj = value  # needed for BC with calling enter from CM code
         self.source = source

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         # This shouldn't be called unless we have a source
         assert self.source
         self.source.reconstruct(codegen)
@@ -426,7 +427,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
         )
         return ConstantVariable.create(None)

-    def reconstruct_type(self, codegen):
+    def reconstruct_type(self, codegen: "PyCodegen"):
         ty = NoEnterTorchFunctionMode
         codegen(
             AttrSource(

torch/_dynamo/variables/user_defined.py

@@ -97,6 +97,7 @@ except ImportError:

 if TYPE_CHECKING:
+    from torch._dynamo.codegen import PyCodegen
     from torch._dynamo.symbolic_convert import InstructionTranslator

@@ -1507,7 +1508,7 @@ class RemovableHandleVariable(VariableTracker):
             return variables.ConstantVariable.create(None)
         super().call_method(tx, method_name, args, kwargs)

-    def reconstruct(self, codegen):
+    def reconstruct(self, codegen: "PyCodegen"):
         if self.idx == self.REMOVED:
             # Hook has already been removed, return a dummy handle
             codegen.add_push_null(
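
Since every change above is an annotation or a `TYPE_CHECKING` import, the commit alters no runtime behavior; the additions are consumed only by static analysis. A change like this would typically be validated by running the type checker over the touched files, for example (the exact invocation depends on the repository's mypy configuration, which this page does not show):

    mypy torch/_dynamo/codegen.py torch/_dynamo/source.py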