mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Dynamo][Misc] Apply typing hints for codegen (#150289)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/150289 Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
parent
295b7e21eb
commit
1b0a023dde
|
|
@ -18,7 +18,7 @@ import re
|
|||
import sys
|
||||
import types
|
||||
from collections import Counter
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch.nn
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
|
@ -54,6 +54,10 @@ from .variables.tensor import (
|
|||
from .variables.torch_function import TensorWithTFOverrideVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .symbolic_convert import InstructionTranslatorBase
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GraphOutputEntry:
|
||||
index: int
|
||||
|
|
@ -67,7 +71,7 @@ class PyCodegen:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
tx=None,
|
||||
tx: "InstructionTranslatorBase",
|
||||
root: Optional[torch.nn.Module] = None,
|
||||
graph_output_var: Optional[str] = None,
|
||||
tempvars=None,
|
||||
|
|
|
|||
|
|
@ -390,7 +390,7 @@ class OutputGraph:
|
|||
# and LOAD_ATTR for same python objects free.
|
||||
self.variable_tracker_cache = VariableTrackerCache()
|
||||
self.unique_var_id = itertools.count()
|
||||
self.code_options = dict(code_options)
|
||||
self.code_options: dict[str, Any] = dict(code_options)
|
||||
self.output_instructions: list[Instruction] = []
|
||||
# used to track nodes that are added between calls of copy_graphstate
|
||||
# and restore_graphstate
|
||||
|
|
@ -401,7 +401,7 @@ class OutputGraph:
|
|||
|
||||
# Not checkpointed
|
||||
self.compiler_fn: Optional[CompilerFn] = compiler_fn
|
||||
self.global_scope = global_scope
|
||||
self.global_scope: Scope = global_scope
|
||||
self.local_scope = local_scope
|
||||
self.root_tx = root_tx
|
||||
|
||||
|
|
@ -462,7 +462,7 @@ class OutputGraph:
|
|||
self.random_calls: list[
|
||||
tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
|
||||
] = []
|
||||
self.random_values_var = None
|
||||
self.random_values_var: Any = None
|
||||
|
||||
# Bytecode to insert right before we call the graph
|
||||
self.pregraph_bytecode: list[Instruction] = []
|
||||
|
|
@ -888,7 +888,9 @@ class OutputGraph:
|
|||
self.output.update_co_names(module_key)
|
||||
self.global_scope[module_key] = target
|
||||
return VariableTracker.build(
|
||||
self, target, ConstantSource(source_name=module_key)
|
||||
self, # type: ignore[arg-type]
|
||||
target,
|
||||
ConstantSource(source_name=module_key),
|
||||
)
|
||||
|
||||
for k, v in self.nn_modules.items():
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ the code needed to recreate values.
|
|||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._guards import ChainedSource, GuardSource, Source
|
||||
|
||||
|
|
@ -29,6 +29,9 @@ from . import utils
|
|||
from .bytecode_transformation import create_call_function, create_instruction
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .codegen import PyCodegen
|
||||
|
||||
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
|
||||
# so those cases are omitted intentionally
|
||||
|
||||
|
|
@ -120,7 +123,7 @@ class LocalSource(Source):
|
|||
# or `co_freevars`.
|
||||
is_derefed_cell_contents: bool = False
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
if self.is_derefed_cell_contents:
|
||||
codegen.load_deref(self.local_name)
|
||||
else:
|
||||
|
|
@ -137,7 +140,7 @@ class LocalSource(Source):
|
|||
class SyntheticLocalSource(Source):
|
||||
local_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.append_output(codegen.create_load(self.local_name))
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -154,7 +157,7 @@ class RandomValueSource(Source):
|
|||
def guard_source(self):
|
||||
return GuardSource.RANDOM_VALUE
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
||||
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||
|
|
@ -167,7 +170,7 @@ class RandomValueSource(Source):
|
|||
class GlobalSource(Source):
|
||||
global_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -181,7 +184,7 @@ class GlobalSource(Source):
|
|||
class GlobalWeakRefSource(Source):
|
||||
global_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.append_output(
|
||||
codegen.create_load_global(self.global_name, add=True)
|
||||
|
|
@ -198,7 +201,7 @@ class GlobalWeakRefSource(Source):
|
|||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WeakRefCallSource(ChainedSource):
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen(self.base))
|
||||
codegen.extend_output(create_call_function(0, False))
|
||||
|
||||
|
|
@ -227,7 +230,7 @@ class AttrSource(ChainedSource):
|
|||
)
|
||||
object.__setattr__(self, "member", member_parts[-1])
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||
|
||||
|
|
@ -249,7 +252,7 @@ class LocalCellSource(Source):
|
|||
|
||||
local_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
||||
# Dynamo's bytecode transformation differentiates them slightly, so we
|
||||
# always emit `LOAD_CLOSURE` here.
|
||||
|
|
@ -267,7 +270,7 @@ class LocalCellSource(Source):
|
|||
class GradSource(ChainedSource):
|
||||
member: str = "grad"
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||
|
||||
|
|
@ -342,7 +345,7 @@ class TensorPropertySource(ChainedSource):
|
|||
else:
|
||||
assert self.idx is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
utils.__name__, f"call_{self.prop.method_name()}"
|
||||
|
|
@ -378,7 +381,7 @@ class IndexedSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -393,7 +396,7 @@ class NegateSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -450,7 +453,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource):
|
|||
|
||||
|
||||
class AttrProxySource(ChainedSource):
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -484,7 +487,7 @@ class DefaultsSource(ChainedSource):
|
|||
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.field))
|
||||
codegen.append_output(codegen.create_load_const(self.idx_key))
|
||||
|
|
@ -509,7 +512,7 @@ class GetItemSource(ChainedSource):
|
|||
super().__setattr__("index", self.index.__reduce__())
|
||||
super().__setattr__("index_is_slice", True)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
if self.index_is_slice:
|
||||
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
|
||||
|
|
@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource):
|
|||
def guard_source(self):
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
|
||||
)
|
||||
|
|
@ -577,7 +580,7 @@ class DictGetItemSource(ChainedSource):
|
|||
def guard_source(self):
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# reconstruct dict.__getitem__(dct, key)
|
||||
|
||||
# Load dict.__getitem__
|
||||
|
|
@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource):
|
|||
Same as GetItemSource with reconstruct and name overridden to be list specific.
|
||||
"""
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
|
||||
# from possibly overridden __getitem__.
|
||||
|
||||
|
|
@ -646,7 +649,7 @@ class ListGetItemSource(GetItemSource):
|
|||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TupleIteratorGetItemSource(GetItemSource):
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
|
||||
)
|
||||
|
|
@ -663,7 +666,7 @@ class TypeSource(ChainedSource):
|
|||
def __post_init__(self):
|
||||
assert self.base is not None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
|
||||
codegen(self.base)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
|
@ -677,7 +680,7 @@ class TypeSource(ChainedSource):
|
|||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class OptimizerSource(ChainedSource):
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -689,7 +692,7 @@ class OptimizerSource(ChainedSource):
|
|||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NNModuleSource(ChainedSource):
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.base)
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -738,7 +741,7 @@ class TorchFunctionModeStackSource(Source):
|
|||
|
||||
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
utils.__name__, "get_torch_function_mode_stack_at"
|
||||
|
|
@ -755,7 +758,7 @@ class TorchFunctionModeStackSource(Source):
|
|||
class ConstantSource(Source):
|
||||
source_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
|
||||
|
||||
def guard_source(self):
|
||||
|
|
@ -776,7 +779,7 @@ class NumpyTensorSource(ChainedSource):
|
|||
def guard_source(self):
|
||||
return self.base.guard_source()
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
|
||||
codegen(self.base)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ from ..utils import cmp_name_to_op_mapping, istype
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase
|
||||
from ..codegen import PyCodegen
|
||||
from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase
|
||||
|
||||
|
||||
class SourceType(Enum):
|
||||
|
|
@ -399,7 +400,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx) -> list["VariableTracker"]:
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ except ModuleNotFoundError:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -348,7 +349,7 @@ class GraphArg:
|
|||
self._example = TensorWeakRef(self._example)
|
||||
assert is_fake(self.fake_tensor)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.source)
|
||||
|
||||
def erase(self):
|
||||
|
|
@ -369,7 +370,7 @@ class BackwardStateGraphArg(GraphArg):
|
|||
is_tensor=False,
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
assert codegen.tx.output.backward_state_var
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState")
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
|
|||
|
||||
if TYPE_CHECKING:
|
||||
# Cyclic dependency...
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -730,7 +731,7 @@ class BuiltinVariable(VariableTracker):
|
|||
return DTYPE[self.fn]
|
||||
return super().as_proxy()
|
||||
|
||||
def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
name = self.fn.__name__
|
||||
assert self.fn.__module__ == "builtins"
|
||||
assert name not in codegen.tx.f_globals, "shadowed global"
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ its type to `common_constant_types`.
|
|||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
tx: "InstructionTranslator",
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ from .user_defined import UserDefinedObjectVariable
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -85,12 +86,12 @@ class ContextWrappingVariable(VariableTracker):
|
|||
self.cleanup_assert()
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
def reconstruct_type(self, codegen):
|
||||
def reconstruct_type(self, codegen: "PyCodegen"):
|
||||
codegen(
|
||||
AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: self.reconstruct_type(codegen))
|
||||
target_values = self.target_values
|
||||
if not target_values:
|
||||
|
|
@ -1057,7 +1058,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
|
|||
_unsafe_set_version_counter
|
||||
).call_function(tx, [self.tensors, self.prev_versions], {})
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
unimplemented_v2(
|
||||
gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
|
||||
context=str(self),
|
||||
|
|
@ -1278,7 +1279,7 @@ class StreamVariable(VariableTracker):
|
|||
def as_proxy(self):
|
||||
return self.proxy
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# If we got here, this stream is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
|
|
@ -1340,7 +1341,7 @@ class EventVariable(VariableTracker):
|
|||
def as_proxy(self):
|
||||
return self.proxy
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# If we got here, this event is fully subsumed by the graph - this means it is
|
||||
# not an input or global
|
||||
assert not self.source
|
||||
|
|
@ -1378,7 +1379,7 @@ class WithExitFunctionVariable(VariableTracker):
|
|||
assert not kwargs
|
||||
return self.ctx.exit(tx, *args)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# Note here we reconstruct the context manager rather than the
|
||||
# exit function. The handler generated by BlockStackEntry
|
||||
# will re-enter the context in the resume function.
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from .constant import ConstantVariable
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -263,7 +264,7 @@ class ConstDictVariable(VariableTracker):
|
|||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# instructions to load collections.OrderedDict if necessary
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
codegen.add_push_null(
|
||||
|
|
@ -546,7 +547,7 @@ class MappingProxyVariable(VariableTracker):
|
|||
def unpack_var_sequence(self, tx):
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
unimplemented(
|
||||
|
|
@ -681,7 +682,7 @@ class SetVariable(ConstDictVariable):
|
|||
def as_python_constant(self):
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
|
|
@ -786,7 +787,7 @@ class FrozensetVariable(SetVariable):
|
|||
def as_python_constant(self):
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
|
|
@ -879,7 +880,7 @@ class DictViewVariable(VariableTracker):
|
|||
def unpack_var_sequence(self, tx):
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ except ModuleNotFoundError:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._higher_order_ops.triton_kernel_wrap import (
|
||||
TritonGridType,
|
||||
|
|
@ -470,7 +471,7 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
|||
|
||||
__repr__ = __str__
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
from torch._dynamo.side_effects import disallow_side_effects_in_generator
|
||||
from torch._dynamo.symbolic_convert import (
|
||||
InstructionTranslator,
|
||||
|
|
@ -1109,7 +1110,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
|
|||
|
||||
return result
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(__name__, "_create_nested_fn")
|
||||
)
|
||||
|
|
@ -1506,7 +1507,7 @@ class FunctoolsPartialVariable(VariableTracker):
|
|||
def python_type(self):
|
||||
return functools.partial
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
|
||||
codegen(self.func)
|
||||
if self.args:
|
||||
|
|
@ -1962,7 +1963,7 @@ class TMADescriptorVariable(VariableTracker):
|
|||
self.element_size.as_proxy(),
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from(
|
||||
"triton.tools.experimental_descriptor",
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from .constant import ConstantVariable
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -249,7 +250,7 @@ class RepeatIteratorVariable(IteratorVariable):
|
|||
def next_variable(self, tx):
|
||||
return self.item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
|
|
@ -279,7 +280,7 @@ class CountIteratorVariable(IteratorVariable):
|
|||
self.item = self.item.call_method(tx, "__add__", [self.step], {})
|
||||
return old_item
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
|
|
@ -425,7 +426,7 @@ class ZipVariable(IteratorVariable):
|
|||
self.index += 1
|
||||
return variables.TupleVariable(args)
|
||||
|
||||
def reconstruct_items(self, codegen):
|
||||
def reconstruct_items(self, codegen: "PyCodegen"):
|
||||
for it in self.iterables:
|
||||
if isinstance(it, list):
|
||||
remaining_items = it[self.index :]
|
||||
|
|
@ -436,7 +437,7 @@ class ZipVariable(IteratorVariable):
|
|||
else:
|
||||
codegen(it)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
|
||||
)
|
||||
|
|
@ -481,7 +482,7 @@ class MapVariable(ZipVariable):
|
|||
args = super().next_variable(tx)
|
||||
return self.fn.call_function(tx, args.items, {})
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
|
||||
)
|
||||
|
|
@ -555,7 +556,7 @@ class FilterVariable(IteratorVariable):
|
|||
if pred_res.as_python_constant():
|
||||
return item
|
||||
|
||||
def reconstruct_items(self, codegen):
|
||||
def reconstruct_items(self, codegen: "PyCodegen"):
|
||||
if isinstance(self.iterable, list):
|
||||
remaining_items = self.iterable[self.index :]
|
||||
codegen.foreach(remaining_items)
|
||||
|
|
@ -565,7 +566,7 @@ class FilterVariable(IteratorVariable):
|
|||
else:
|
||||
codegen(self.iterable)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter"))
|
||||
codegen(self.fn)
|
||||
self.reconstruct_items(codegen)
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObject
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ class SuperVariable(VariableTracker):
|
|||
# cls for a classmethod)
|
||||
self.objvar = objvar
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
|
||||
codegen(self.typevar)
|
||||
if self.objvar is not None:
|
||||
|
|
@ -331,7 +332,7 @@ class ExceptionVariable(VariableTracker):
|
|||
def set_context(self, context: "ExceptionVariable"):
|
||||
self.__context__ = context
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
|
||||
)
|
||||
|
|
@ -460,7 +461,7 @@ class ComptimeVariable(VariableTracker):
|
|||
Dynamo compile time
|
||||
"""
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
raise NotImplementedError("comptime is special form")
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
|
|
@ -944,7 +945,7 @@ class GetAttrVariable(VariableTracker):
|
|||
raise NotImplementedError
|
||||
return inspect.getattr_static(step2, name)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.obj)
|
||||
codegen.extend_output(codegen.create_load_attrs(self.name))
|
||||
|
||||
|
|
@ -1161,7 +1162,7 @@ class TypingVariable(VariableTracker):
|
|||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
||||
def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# We're just trying to load the type here. Reconstructing the type from
|
||||
# scratch is tricky - for a type like `typing.List[int]` we'd need to
|
||||
# deconstruct the origin and args. The origin for `List[int]` is `list`
|
||||
|
|
@ -1336,7 +1337,7 @@ class NullVariable(VariableTracker):
|
|||
def __repr__(self) -> str:
|
||||
return "NullVariable"
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
if sys.version_info < (3, 11):
|
||||
unimplemented("cannot reconstruct NullVariable in < Python 3.11")
|
||||
codegen.append_output(create_instruction("PUSH_NULL"))
|
||||
|
|
@ -1377,7 +1378,7 @@ class StringFormatVariable(VariableTracker):
|
|||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
|
|
@ -1426,7 +1427,7 @@ class DebuggingVariable(VariableTracker):
|
|||
|
||||
tx.debug_locals.append((self, list(args)))
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
return self.source.reconstruct(codegen)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -1721,7 +1722,7 @@ class RandomVariable(VariableTracker):
|
|||
return call_random_fn(tx, call_random_meth, args, kwargs)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
|
|
@ -1762,7 +1763,7 @@ class WeakRefVariable(VariableTracker):
|
|||
) -> "VariableTracker":
|
||||
return self.referent_vt
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
|
||||
codegen(self.referent_vt)
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from .base import VariableTracker
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
|
||||
|
|
@ -36,7 +37,7 @@ class SDPAParamsVariable(VariableTracker):
|
|||
self.param_vars = param_vars
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
assert self.source is None
|
||||
assert self.param_vars is not None
|
||||
codegen.add_push_null(
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ except ModuleNotFoundError:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -1558,7 +1559,7 @@ class UntypedStorageVariable(VariableTracker):
|
|||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.from_tensor)
|
||||
codegen.load_method("untyped_storage")
|
||||
codegen.call_method(0)
|
||||
|
|
@ -1573,7 +1574,7 @@ class DataPtrVariable(VariableTracker):
|
|||
super().__init__(**kwargs)
|
||||
self.from_tensor = from_tensor
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.from_tensor)
|
||||
codegen.load_method("data_ptr")
|
||||
codegen.call_method(0)
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ class BaseTorchVariable(VariableTracker):
|
|||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
try:
|
||||
name = f"{self.value.__module__}.{self.value.__name__}"
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ from .user_defined import UserDefinedObjectVariable
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -382,7 +383,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
|||
self.cm_obj = value # needed for BC with calling enter from CM code
|
||||
self.source = source
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# This shouldn't be called unless we have a source
|
||||
assert self.source
|
||||
self.source.reconstruct(codegen)
|
||||
|
|
@ -426,7 +427,7 @@ class TorchFunctionModeVariable(GenericContextWrappingVariable):
|
|||
)
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def reconstruct_type(self, codegen):
|
||||
def reconstruct_type(self, codegen: "PyCodegen"):
|
||||
ty = NoEnterTorchFunctionMode
|
||||
codegen(
|
||||
AttrSource(
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ except ImportError:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.codegen import PyCodegen
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
|
|
@ -1507,7 +1508,7 @@ class RemovableHandleVariable(VariableTracker):
|
|||
return variables.ConstantVariable.create(None)
|
||||
super().call_method(tx, method_name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
if self.idx == self.REMOVED:
|
||||
# Hook has already been removed, return a dummy handle
|
||||
codegen.add_push_null(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user