pytorch/torch/_dynamo/source.py
lezcano eb2bdfae88 Make variables in dict LazyTrackers (not lazily guarded yet) and avoid using DICT_KEYS guard (#117625)
Make variables in dict lazy and remove DICT_KEYS guard.

We build the keys of a dict depth-first and we rely on the guards of
each element in the dict to create the correct guards. This allows us to
remove the rather buggy DICT_KEYS guard and make the guard lazy.
The guards are not completely lazy yet, as we instantiate them in
`_HashableTracker._eq_impl` but it should be possible to make them
truly lazy.

Also, adding new types to the supported types within keys should be less
error prone.

This is marginally less efficient when we graph break, but in turn we
should graph break much less. It also makes the dict code easier to maintain
(removes `is_hashable_python_var`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117625
Approved by: https://github.com/jansel, https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982, #118098, #117983
2024-02-02 14:38:08 +00:00

516 lines
15 KiB
Python

import collections
import dataclasses
import enum
from typing import Any, Optional, Union
from torch._guards import ChainedSource, GuardSource, Source
from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr
# Lookup table: guard source of a base object -> guard source to report once
# the value is treated as an nn.Module (see ParamBufferSource / NNModuleSource).
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so the *_FSDP_MODULE keys are omitted intentionally.
_GUARD_SOURCE_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}
# Lookup table: guard source of a base object -> the matching *_FSDP_MODULE
# guard source (see FSDPNNModuleSource). Local kinds map to LOCAL_FSDP_MODULE
# and global kinds map to GLOBAL_FSDP_MODULE.
_GUARD_SOURCE_FSDP_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
# Lookup table: guard source of a base object -> the plain LOCAL / GLOBAL
# guard source, dropping any NN_MODULE / FSDP_MODULE qualifier
# (see NotNNModuleSource).
_GUARD_SOURCE_NOT_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL,
    GuardSource.GLOBAL: GuardSource.GLOBAL,
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
}
def is_constant_source(source):
    """Return True if ``source`` denotes a constant.

    A source counts as constant when it is a ``ConstantSource`` instance or
    when its ``guard_source()`` equals ``GuardSource.CONSTANT``.  Sources whose
    ``guard_source()`` raises ``NotImplementedError`` are reported as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        return bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        # No guard source available: treat as not constant.
        return False
def is_input_source(source):
    """Return True if ``source.guard_source()`` is one of the local/global kinds.

    The accepted kinds are the plain, NN-module and FSDP-module variants of
    both ``LOCAL`` and ``GLOBAL``.
    """
    kind = source.guard_source()
    return kind in (
        GuardSource.LOCAL,
        GuardSource.GLOBAL,
        GuardSource.LOCAL_NN_MODULE,
        GuardSource.GLOBAL_NN_MODULE,
        GuardSource.LOCAL_FSDP_MODULE,
        GuardSource.GLOBAL_FSDP_MODULE,
    )
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Build the instructions that load ``source.base`` followed by its index.

    The list returned by ``source.base.reconstruct(codegen)`` is extended in
    place and returned.  The index is loaded in one of three ways:

    * a ``Source`` index is reconstructed through its own ``reconstruct``;
    * when ``index_is_slice`` is true, the slice rebuilt by
      ``source.unpack_slice()`` is loaded as a constant (only valid for
      ``GetItemSource``);
    * otherwise ``source.index`` itself is loaded as a constant.

    The subscript / call instruction itself is left to the caller.
    """
    instrs = source.base.reconstruct(codegen)
    if isinstance(source.index, Source):
        instrs.extend(source.index.reconstruct(codegen))
    elif index_is_slice:
        assert isinstance(source, GetItemSource)
        instrs.append(codegen.create_load_const(source.unpack_slice()))
    else:
        instrs.append(codegen.create_load_const(source.index))
    return instrs
@dataclasses.dataclass(frozen=True)
class LocalSource(Source):
local_name: str
cell_or_freevar: bool = False
def reconstruct(self, codegen):
return [codegen.create_load(self.local_name)]
def guard_source(self):
return GuardSource.LOCAL
def name(self):
return f"L[{repr(self.local_name)}]"
@dataclasses.dataclass(frozen=True)
class RandomValueSource(Source):
random_call_index: int
def guard_source(self):
return GuardSource.RANDOM_VALUE
def reconstruct(self, codegen):
return [
codegen.create_load(codegen.tx.output.random_values_var),
codegen.create_load_const(self.random_call_index),
create_instruction("BINARY_SUBSCR"),
]
def name(self):
return f"random_value_{self.random_call_index}"
@dataclasses.dataclass(frozen=True)
class GlobalSource(Source):
global_name: str
def reconstruct(self, codegen):
return [codegen.create_load_global(self.global_name, False, add=True)]
def guard_source(self):
return GuardSource.GLOBAL
def name(self):
return f"G[{repr(self.global_name)}]"
@dataclasses.dataclass(frozen=True)
class GlobalWeakRefSource(Source):
global_name: str
def reconstruct(self, codegen):
return [
codegen.create_load_global(self.global_name, True, add=True),
*create_call_function(0, False),
]
def guard_source(self):
return GuardSource.GLOBAL
def name(self):
return f"G[{repr(self.global_name)}]()"
@dataclasses.dataclass(frozen=True)
class AttrSource(ChainedSource):
    """Source for the attribute ``member`` of ``base``.

    A dotted ``member`` such as ``"a.b.c"`` is normalised at construction time
    into nested sources: everything before the last dot becomes a new
    ``AttrSource`` used as ``base`` and only the last component is kept as
    ``member``.
    """

    member: str

    def __post_init__(self):
        assert self.base, "Can't construct an AttrSource without a valid base source"
        if "." in self.member:
            # The dataclass is frozen, so fields are rewritten through
            # object.__setattr__.
            parent_path, _, leaf = self.member.rpartition(".")
            object.__setattr__(self, "base", AttrSource(self.base, parent_path))
            object.__setattr__(self, "member", leaf)

    def name(self):
        if self.member.isidentifier():
            return f"{self.base.name()}.{self.member}"
        # Not expressible with dot syntax: fall back to a getattr() call.
        return f"getattr({self.base.name()}, {self.member!r})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        base_instrs = self.base.reconstruct(codegen)
        return base_instrs + codegen.create_load_attrs(self.member)
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """``AttrSource`` whose guard source is remapped via ``_GUARD_SOURCE_NN_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
class TensorProperty(enum.Enum):
    """Tensor properties that a ``TensorPropertySource`` can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the method name for this property.

        The result is the lower-cased member name: ``"size"``, ``"stride"``
        or ``"storage_offset"``.
        """
        return self.name.lower()
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for a size, stride or storage offset of the tensor at ``base``.

    ``SIZE`` and ``STRIDE`` require a dimension index ``idx``;
    ``STORAGE_OFFSET`` requires ``idx`` to be ``None``.  Both rules are
    asserted at construction time.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
        else:
            assert self.idx is not None

    def name(self):
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # Emits base.<method>(idx) for size/stride, base.storage_offset()
        # otherwise; the index is the call's only argument when present.
        takes_index = self.idx is not None
        instructions = list(self.base.reconstruct(codegen))
        instructions.append(codegen.create_load_attr(self.prop.method_name()))
        if takes_index:
            instructions.append(codegen.create_load_const(self.idx))
        instructions.extend(create_call_function(1 if takes_index else 0, True))
        return instructions
@dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource):
    """Source for the negation of ``base``, rendered as ``<base>.__neg__()``.

    It can be named and guarded on, but ``reconstruct`` is not implemented.
    """

    def __post_init__(self):
        assert self.base is not None

    def name(self):
        # NB: use method call so that function stripping regexes work
        return f"{self.base.name()}.__neg__()"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        raise NotImplementedError()
@dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource):
    """Source rendered as ``cast_symbool_to_symint_guardless(<base>)``.

    Only the name is wrapped: reconstruction and guard source are taken from
    ``base`` unchanged.
    """

    def __post_init__(self):
        assert self.base is not None

    def name(self):
        return f"cast_symbool_to_symint_guardless({self.base.name()})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen)
@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
    """Source for one default argument value of the callable at ``base``.

    With ``is_kw`` set, ``idx_key`` must be a ``str`` and the value is read
    from ``base.__kwdefaults__[idx_key]``; otherwise ``idx_key`` must be an
    ``int`` and the value is read from ``base.__defaults__[idx_key]``.
    ``field`` and ``_name`` are derived once in ``__post_init__`` and excluded
    from init, repr and comparison.
    """

    idx_key: Union[int, str]
    is_kw: bool = False
    field: str = dataclasses.field(init=False, repr=False, compare=False)
    _name: str = dataclasses.field(init=False, repr=False, compare=False)

    def __post_init__(self):
        assert (
            self.base
        ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
        # The dataclass is frozen, so derived fields are stored through
        # object.__setattr__.
        if self.is_kw:
            assert isinstance(self.idx_key, str)
            object.__setattr__(self, "field", "__kwdefaults__")
            rendered = f"{self.base.name()}.{self.field}['{self.idx_key}']"
        else:
            assert isinstance(self.idx_key, int)
            object.__setattr__(self, "field", "__defaults__")
            rendered = f"{self.base.name()}.{self.field}[{self.idx_key}]"
        object.__setattr__(self, "_name", rendered)

    def name(self):
        return self._name

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # base.<field>[idx_key]
        instrs = self.base.reconstruct(codegen)
        instrs.extend(codegen.create_load_attrs(self.field))
        instrs.append(codegen.create_load_const(self.idx_key))
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs
@dataclasses.dataclass(frozen=True)
class GetItemSource(ChainedSource):
    """Source for ``base[index]``.

    A ``slice`` index is replaced at construction time by its ``__reduce__()``
    form, with ``index_is_slice`` set to True; ``unpack_slice()`` rebuilds the
    slice object from that form.
    """

    index: Any
    index_is_slice: bool = False

    def __post_init__(self):
        assert self.base is not None
        if not isinstance(self.index, slice):
            return
        # store the hashable version of the slice so the whole GetItemSource is hashable
        # (frozen dataclass: plain attribute assignment is not available here)
        super().__setattr__("index", self.index.__reduce__())
        super().__setattr__("index_is_slice", True)

    def unpack_slice(self):
        """Rebuild the slice from the stored ``(class, args)`` pair."""
        assert self.index_is_slice
        slice_class, slice_args = self.index
        return slice_class(*slice_args)

    def name(self):
        # The index is rendered according to its kind:
        # 1) a Source -- only ConstDictKeySource is accepted
        # 2) a slice - example 1:4
        # 3) an enum.Enum member
        # 4) any other constant - example string, integer
        if isinstance(self.index, Source):
            if not isinstance(self.index, ConstDictKeySource):
                raise ValueError(
                    "GetItemSource index must be a constant, enum or ConstDictKeySource"
                )
            return f"{self.base.name()}[{self.index.name()}]"
        if self.index_is_slice:
            return f"{self.base.name()}[{self.unpack_slice()!r}]"
        if isinstance(self.index, enum.Enum):
            return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
        return f"{self.base.name()}[{self.index!r}]"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # Load base and index, then subscript.
        operands = reconstruct_getitem(
            self, codegen, index_is_slice=self.index_is_slice
        )
        return [*operands, create_instruction("BINARY_SUBSCR")]
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """Source for a dict key, rendered as ``___dict_keys_getitem(<base>, <index>)``.

    Reconstruction calls the ``dict_keys_getitem`` helper from the ``utils``
    module with the base and the constant ``index``.
    """

    def is_dict_key(self):
        return True

    def name(self):
        return f"___dict_keys_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen):
        # helper, base, index, then a 2-argument call
        instrs = list(
            codegen.create_load_import_from(utils.__name__, "dict_keys_getitem")
        )
        instrs.extend(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_const(self.index))
        instrs.extend(create_call_function(2, True))
        return instrs
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Source rendered as ``___tuple_iterator_getitem(<base>, <index>)``.

    Reconstruction calls the ``tuple_iterator_getitem`` helper from the
    ``utils`` module with the base and the constant ``index``.
    """

    def reconstruct(self, codegen):
        """Return the instructions for ``tuple_iterator_getitem(base, index)``.

        The helper load is returned as part of the list, the same way
        ``ConstDictKeySource.reconstruct`` does it, instead of being issued
        through the side-effecting ``codegen.load_import_from``.  That keeps
        the helper load directly in front of its arguments even when the
        caller has already collected instructions that must come before this
        source's (e.g. a callable loaded ahead of its operands).
        """
        return [
            *codegen.create_load_import_from(utils.__name__, "tuple_iterator_getitem"),
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            *create_call_function(2, True),
        ]

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource):
    """Source for ``type(<base>)``."""

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        """Return the instructions for ``builtins.type(base)``.

        The load of ``type`` is returned as part of the list (via
        ``create_load_import_from``, as ``ConstDictKeySource`` does) instead
        of being issued through the side-effecting
        ``codegen.load_import_from``.  That keeps the callable directly in
        front of its argument even when the caller has already collected
        instructions that must come before this source's.
        """
        return [
            *codegen.create_load_import_from("builtins", "type"),
            *self.base.reconstruct(codegen),
            *create_call_function(1, True),
        ]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"type({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class ODictGetItemSource(ChainedSource):
    """Source rendered as ``___odict_getitem(<base>, <index>)``.

    Reconstruction calls ``collections.OrderedDict.__getitem__`` with the base
    and the index as its two arguments.
    """

    index: Any

    def __post_init__(self):
        assert self.base is not None

    def name(self):
        if isinstance(self.index, type):
            # Classes are referred to by module and qualified name.
            rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
            return f"___odict_getitem({self.base.name()}, {rep})"
        if isinstance(self.index, Source):
            return f"___odict_getitem({self.base.name()}, {self.index.name()})"
        return f"___odict_getitem({self.base.name()}, {self.index!r})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # callable, then (base, index), then a 2-argument call
        instrs = [codegen._create_load_const(collections.OrderedDict.__getitem__)]
        instrs.extend(reconstruct_getitem(self, codegen, index_is_slice=False))
        instrs.extend(create_call_function(2, True))
        return instrs
@dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource):
    """Wrapper that changes only the guard source of ``base``.

    Name and reconstruction are delegated to ``base``; the guard source is
    remapped through ``_GUARD_SOURCE_NN_MODULE``.
    """

    def name(self):
        return self.base.name()

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen)
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant remapping through ``_GUARD_SOURCE_NOT_NN_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant remapping through ``_GUARD_SOURCE_FSDP_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.GLOBAL
@dataclasses.dataclass(frozen=True)
class ConstantSource(Source):
source_name: str
def reconstruct(self, codegen):
return [codegen.create_load_global(self.source_name, False, add=False)]
def guard_source(self):
return GuardSource.CONSTANT
def name(self):
return self.source_name
def make_guard(self, fn):
raise NotImplementedError()
@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
    """Source rendered as ``___from_numpy(<base>)``.

    Reconstruction calls ``torch.as_tensor`` on the reconstructed base.
    """

    def name(self) -> str:
        return f"___from_numpy({self.base.name()})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        """Return the instructions for ``torch.as_tensor(base)``.

        The load of ``as_tensor`` is returned as part of the list (via
        ``create_load_import_from``, as ``ConstDictKeySource`` does) instead
        of being issued through the side-effecting
        ``codegen.load_import_from``.  That keeps the callable directly in
        front of its argument even when the caller has already collected
        instructions that must come before this source's.
        """
        return [
            *codegen.create_load_import_from("torch", "as_tensor"),
            *self.base.reconstruct(codegen),
            *create_call_function(1, True),
        ]
# This is a synthetic source that is associated with the singleton
# shape env guard we always register for all frames. We get the actual
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.SHAPE_ENV
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True if the root of ``source``'s chain is a ``LocalSource``.

    ``ChainedSource`` wrappers are followed through ``.base`` until a
    non-chained source is reached.  When ``allow_cell_or_freevar`` is false, a
    root ``LocalSource`` whose ``cell_or_freevar`` flag is set is rejected.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    return not (not allow_cell_or_freevar and root.cell_or_freevar)
# TODO: can probably write a generic "test this on everything in the chain"
# helper
def is_from_defaults(source: Source):
    """Return True if ``source`` or any source in its ``.base`` chain is a ``DefaultsSource``."""
    current = source
    while True:
        if isinstance(current, DefaultsSource):
            return True
        if not isinstance(current, ChainedSource):
            return False
        current = current.base