mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make variables in dict lazy and remove DICT_KEYS guard.

We build the keys of a dict depth-first and we rely on the guards of each element in the dict to create the correct guards. This allows us to remove the rather buggy DICT_KEYS guard and make the guard lazy. The guards are not completely lazy yet, as we instantiate them in `_HashableTracker._eq_impl`, but it should be possible to make them truly lazy.

Also, adding new types to the supported types within keys should be less error-prone.

This is marginally less efficient when we graph break, but in turn we should graph break much less. It also makes the dicts code easier to maintain (removes `is_hashable_python_var`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117625
Approved by: https://github.com/jansel, https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982, #118098, #117983
516 lines
15 KiB
Python
516 lines
15 KiB
Python
import collections
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Optional, Union
|
|
|
|
from torch._guards import ChainedSource, GuardSource, Source
|
|
|
|
from . import utils
|
|
from .bytecode_transformation import create_call_function, create_instruction
|
|
from .utils import enum_repr
|
|
|
|
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
|
|
# so those cases are omitted intentionally
|
|
_GUARD_SOURCE_NN_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
|
|
}
|
|
|
|
_GUARD_SOURCE_FSDP_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
|
|
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
|
|
}
|
|
|
|
_GUARD_SOURCE_NOT_NN_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
|
|
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
|
|
}
|
|
|
|
|
|
def is_constant_source(source):
    """Return True when ``source`` denotes a constant.

    A source counts as constant if it is a ``ConstantSource`` instance or if
    its ``guard_source()`` equals ``GuardSource.CONSTANT``. Sources whose
    ``guard_source()`` raises ``NotImplementedError`` are reported as not
    constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        is_constant = source.guard_source() == GuardSource.CONSTANT
    except NotImplementedError:
        # the source does not implement guard_source(): not a constant
        return False
    return True if is_constant else False
|
|
|
|
|
|
def is_input_source(source):
|
|
return source.guard_source() in [
|
|
GuardSource.LOCAL,
|
|
GuardSource.GLOBAL,
|
|
GuardSource.LOCAL_NN_MODULE,
|
|
GuardSource.GLOBAL_NN_MODULE,
|
|
GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL_FSDP_MODULE,
|
|
]
|
|
|
|
|
|
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Build the instructions that push ``source.base`` and then its index.

    The list returned by ``source.base.reconstruct(codegen)`` is extended in
    place and returned. The index is pushed either by reconstructing it (when
    it is itself a ``Source``) or as a constant; for slices the constant is
    the slice rebuilt through ``source.unpack_slice()``. The subscript / call
    instruction itself is left to the caller.
    """
    instrs = source.base.reconstruct(codegen)
    index = source.index

    if isinstance(index, Source):
        instrs.extend(index.reconstruct(codegen))
        return instrs

    if index_is_slice:
        # only GetItemSource stores slices (in their reduced, hashable form)
        assert isinstance(source, GetItemSource)
        index = source.unpack_slice()

    instrs.append(codegen.create_load_const(index))
    return instrs
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class LocalSource(Source):
|
|
local_name: str
|
|
cell_or_freevar: bool = False
|
|
|
|
def reconstruct(self, codegen):
|
|
return [codegen.create_load(self.local_name)]
|
|
|
|
def guard_source(self):
|
|
return GuardSource.LOCAL
|
|
|
|
def name(self):
|
|
return f"L[{repr(self.local_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class RandomValueSource(Source):
|
|
random_call_index: int
|
|
|
|
def guard_source(self):
|
|
return GuardSource.RANDOM_VALUE
|
|
|
|
def reconstruct(self, codegen):
|
|
return [
|
|
codegen.create_load(codegen.tx.output.random_values_var),
|
|
codegen.create_load_const(self.random_call_index),
|
|
create_instruction("BINARY_SUBSCR"),
|
|
]
|
|
|
|
def name(self):
|
|
return f"random_value_{self.random_call_index}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
return [codegen.create_load_global(self.global_name, False, add=True)]
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalWeakRefSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
return [
|
|
codegen.create_load_global(self.global_name, True, add=True),
|
|
*create_call_function(0, False),
|
|
]
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class AttrSource(ChainedSource):
|
|
member: str
|
|
|
|
def __post_init__(self):
|
|
assert self.base, "Can't construct an AttrSource without a valid base source"
|
|
if "." in self.member:
|
|
member_parts = self.member.split(".")
|
|
object.__setattr__(
|
|
self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
|
|
)
|
|
object.__setattr__(self, "member", member_parts[-1])
|
|
|
|
def reconstruct(self, codegen):
|
|
return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if not self.member.isidentifier():
|
|
return f"getattr({self.base.name()}, {self.member!r})"
|
|
return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """``AttrSource`` variant whose guard source is the nn.Module flavour of
    its base's guard source (looked up in ``_GUARD_SOURCE_NN_MODULE``)."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
|
|
|
|
|
|
class TensorProperty(enum.Enum):
    """Tensor properties that a ``TensorPropertySource`` can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the name of the tensor method that reads this property."""
        # built inside the method: a class-level dict would become an enum member
        method_names = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return method_names.get(self)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for a size / stride entry or the storage offset of ``base``.

    ``idx`` selects the dimension for SIZE and STRIDE and must be ``None``
    for STORAGE_OFFSET; this is asserted on construction.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        takes_index = self.prop is not TensorProperty.STORAGE_OFFSET
        if takes_index:
            assert self.idx is not None
        else:
            assert self.idx is None

    def reconstruct(self, codegen):
        """Emit ``base.<method>(idx)``, or ``base.<method>()`` when ``idx`` is None."""
        has_idx = self.idx is not None
        instrs = list(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_attr(self.prop.method_name()))
        if has_idx:
            instrs.append(codegen.create_load_const(self.idx))
        instrs.extend(create_call_function(1 if has_idx else 0, True))
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        """Return the guard expression for the property.

        Raises:
            AssertionError: if ``prop`` is not a handled ``TensorProperty``.
        """
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NegateSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
raise NotImplementedError()
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
# NB: use method call so that function stripping regexes work
|
|
return f"{self.base.name()}.__neg__()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConvertIntSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
return self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class DefaultsSource(ChainedSource):
|
|
idx_key: Union[int, str]
|
|
is_kw: bool = False
|
|
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
|
|
def __post_init__(self):
|
|
assert (
|
|
self.base
|
|
), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
|
if self.is_kw:
|
|
assert isinstance(self.idx_key, str)
|
|
object.__setattr__(self, "field", "__kwdefaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
|
|
)
|
|
else:
|
|
assert isinstance(self.idx_key, int)
|
|
object.__setattr__(self, "field", "__defaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
|
)
|
|
|
|
def reconstruct(self, codegen):
|
|
instrs = self.base.reconstruct(codegen)
|
|
instrs.extend(codegen.create_load_attrs(self.field))
|
|
instrs.extend(
|
|
[
|
|
codegen.create_load_const(self.idx_key),
|
|
create_instruction("BINARY_SUBSCR"),
|
|
]
|
|
)
|
|
return instrs
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return self._name
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GetItemSource(ChainedSource):
|
|
index: Any
|
|
index_is_slice: bool = False
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
if isinstance(self.index, slice):
|
|
# store the hashable version of the slice so the whole GetItemSource is hashable
|
|
super().__setattr__("index", self.index.__reduce__())
|
|
super().__setattr__("index_is_slice", True)
|
|
|
|
def reconstruct(self, codegen):
|
|
return [
|
|
*reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice),
|
|
create_instruction("BINARY_SUBSCR"),
|
|
]
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def unpack_slice(self):
|
|
assert self.index_is_slice
|
|
slice_class, slice_args = self.index
|
|
return slice_class(*slice_args)
|
|
|
|
def name(self):
|
|
# Index can be of following types
|
|
# 1) ConstDictKeySource
|
|
# 2) enum.Enum
|
|
# 3) index is a slice - example 1:4
|
|
# 4) index is a constant - example string, integer
|
|
if isinstance(self.index, Source):
|
|
if not isinstance(self.index, ConstDictKeySource):
|
|
raise ValueError(
|
|
"GetItemSource index must be a constant, enum or ConstDictKeySource"
|
|
)
|
|
return f"{self.base.name()}[{self.index.name()}]"
|
|
elif self.index_is_slice:
|
|
return f"{self.base.name()}[{self.unpack_slice()!r}]"
|
|
elif isinstance(self.index, enum.Enum):
|
|
return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
|
|
else:
|
|
return f"{self.base.name()}[{self.index!r}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """Source for a dict key, obtained by calling the ``dict_keys_getitem``
    helper of the ``utils`` module with ``(base, index)``."""

    def is_dict_key(self):
        return True

    def name(self):
        return f"___dict_keys_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen):
        """Emit ``dict_keys_getitem(base, index)``."""
        instrs = list(
            codegen.create_load_import_from(utils.__name__, "dict_keys_getitem")
        )
        instrs.extend(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_const(self.index))
        instrs.extend(create_call_function(2, True))
        return instrs
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Source for one element of a tuple iterator, obtained by calling the
    ``tuple_iterator_getitem`` helper of the ``utils`` module with
    ``(base, index)``."""

    def reconstruct(self, codegen):
        """Return the instructions for ``tuple_iterator_getitem(base, index)``.

        The helper import is returned as part of the instruction list (via
        ``codegen.create_load_import_from``) instead of being emitted to the
        codegen output as a side effect. That keeps the sequence correctly
        ordered when an enclosing source places its own instructions in front
        of this one's.
        """
        return [
            *codegen.create_load_import_from(utils.__name__, "tuple_iterator_getitem"),
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            *create_call_function(2, True),
        ]

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TypeSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.load_import_from("builtins", "type")
|
|
return self.base.reconstruct(codegen) + create_call_function(1, True)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"type({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ODictGetItemSource(ChainedSource):
|
|
index: Any
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
return [
|
|
codegen._create_load_const(collections.OrderedDict.__getitem__),
|
|
*reconstruct_getitem(self, codegen, index_is_slice=False),
|
|
*create_call_function(2, True),
|
|
]
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if isinstance(self.index, type):
|
|
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
|
|
return f"___odict_getitem({self.base.name()}, {rep})"
|
|
elif isinstance(self.index, Source):
|
|
return f"___odict_getitem({self.base.name()}, {self.index.name()})"
|
|
else:
|
|
return f"___odict_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NNModuleSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
return self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
|
|
|
|
def name(self):
|
|
return self.base.name()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant that reports the base's guard source with
    any module flavour stripped (via ``_GUARD_SOURCE_NOT_NN_MODULE``)."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant that reports the FSDP-module flavour of the
    base's guard source (via ``_GUARD_SOURCE_FSDP_MODULE``)."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalStateSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConstantSource(Source):
|
|
source_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
return [codegen.create_load_global(self.source_name, False, add=False)]
|
|
|
|
def guard_source(self):
|
|
return GuardSource.CONSTANT
|
|
|
|
def name(self):
|
|
return self.source_name
|
|
|
|
def make_guard(self, fn):
|
|
raise NotImplementedError()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NumpyTensorSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"___from_numpy({self.base.name()})"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.load_import_from("torch", "as_tensor")
|
|
return self.base.reconstruct(codegen) + create_call_function(1, True)
|
|
|
|
|
|
# This is a synthetic source that is associated with the singleton
|
|
# shape env guard we always register for all frames. We get the actual
|
|
# guard contents from the ambient ShapeEnv
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ShapeEnvSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.SHAPE_ENV
|
|
|
|
|
|
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True when the root of ``source``'s chain is a ``LocalSource``.

    ``ChainedSource`` wrappers are followed through ``.base`` until a
    non-chained source is reached. With ``allow_cell_or_freevar=False`` a
    root ``LocalSource`` whose ``cell_or_freevar`` flag is set is rejected.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base

    if not isinstance(root, LocalSource):
        return False
    rejected = not allow_cell_or_freevar and root.cell_or_freevar
    return not rejected
|
|
|
|
|
|
# TODO: can probably write a generic "test this on everything in the chain"
# helper
def is_from_defaults(source: Source):
    """Return True when ``source`` or any source reachable through its
    ``ChainedSource.base`` links is a ``DefaultsSource``."""
    node = source
    while True:
        if isinstance(node, DefaultsSource):
            return True
        if not isinstance(node, ChainedSource):
            # reached the root without meeting a DefaultsSource
            return False
        node = node.base
|