mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
I'm going to need this in the follow-up PR. Instead of storing only Source.name() in Symbol, I now store the full Source. This requires a lot of replumbing. In particular: - Move Source to torch._guards to break cycles - I have to add TensorPropertySource and NegateSource to handle the x.size()[0] and -x codegen that I was previously doing with string manipulation - I tighten up invariants so that I never pass source=None; instead I pass ConstantSource (these are constant sources, right?) and test for that rather than for the source being missing. I think this is more parsimonious - Some mypy wobbles from the new imports. I didn't move LocalSource and friends to torch._guards, but I ended up needing to access them in a few places. The main annoyance with moving them is that I would then also need to move the bytecode codegen machinery, which is not easy to move without bringing in the kitchen sink. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/91057 Approved by: https://github.com/albanD, https://github.com/voznesenskym
303 lines
7.7 KiB
Python
303 lines
7.7 KiB
Python
import collections
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Optional
|
|
|
|
from torch._guards import GuardSource, Source
|
|
|
|
from . import utils
|
|
from .bytecode_transformation import create_instruction
|
|
from .utils import rename_implicit
|
|
|
|
# Map a guard source to its nn.Module counterpart. Plain LOCAL / GLOBAL are
# promoted; the *_NN_MODULE sources are already promoted and map to
# themselves, so applying the mapping more than once gives the same result.
_GUARD_SOURCE_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
}
_GUARD_SOURCE_NN_MODULE.update(
    (already_nn, already_nn)
    for already_nn in (GuardSource.LOCAL_NN_MODULE, GuardSource.GLOBAL_NN_MODULE)
)
|
|
|
|
# Map a guard source to its plain (non-nn.Module) counterpart. LOCAL / GLOBAL
# map to themselves; the *_NN_MODULE sources are demoted to LOCAL / GLOBAL.
_GUARD_SOURCE_NOT_NN_MODULE = {
    plain: plain for plain in (GuardSource.LOCAL, GuardSource.GLOBAL)
}
_GUARD_SOURCE_NOT_NN_MODULE.update(
    {
        GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
        GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
    }
)
|
|
|
|
|
|
def is_constant_source(source):
    """Return True when ``source`` denotes a constant.

    A source is constant if it is a ``ConstantSource`` instance, or if its
    ``guard_source()`` compares equal to ``GuardSource.CONSTANT``. A source
    whose ``guard_source()`` raises ``NotImplementedError`` is reported as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        return bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        return False
|
|
|
|
|
|
@dataclasses.dataclass
class LocalSource(Source):
    """Source for the local variable called ``local_name``."""

    local_name: str

    def name(self):
        # The raw local name is passed through ``rename_implicit`` before
        # being used as this source's name.
        return rename_implicit(self.local_name)

    def guard_source(self):
        return GuardSource.LOCAL

    def reconstruct(self, codegen):
        # A single load of the local reproduces the value.
        load_local = codegen.create_load(self.local_name)
        return [load_local]
|
|
|
|
|
|
@dataclasses.dataclass
class RandomValueSource(Source):
    """Source for one entry of the output's random-values variable.

    Reconstructs as ``<codegen.tx.output.random_values_var>[random_call_index]``.
    """

    random_call_index: int

    def name(self):
        label = f"random_value_{self.random_call_index}"
        return rename_implicit(label)

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen):
        # Emits: load <random values var>, load const <index>, BINARY_SUBSCR.
        values_var = codegen.tx.output.random_values_var
        instrs = [codegen.create_load(values_var)]
        instrs.append(codegen.create_load_const(self.random_call_index))
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs
|
|
|
|
|
|
@dataclasses.dataclass
class GlobalSource(Source):
    """Source for the global variable called ``global_name``."""

    global_name: str

    def name(self):
        # A global is named by its bare identifier.
        return self.global_name

    def guard_source(self):
        return GuardSource.GLOBAL

    def reconstruct(self, codegen):
        load_global = codegen.create_load_global(self.global_name, add=True)
        return [load_global]
|
|
|
|
|
|
@dataclasses.dataclass
class GlobalWeakRefSource(Source):
    """Source for the result of calling the global ``global_name`` with no args.

    Judging by the class name, the global is presumably a weak reference and
    calling it dereferences it (not verifiable from this file). The source's
    name is ``"<global_name>()"``.
    """

    global_name: str

    def name(self):
        return f"{self.global_name}()"

    def guard_source(self):
        return GuardSource.GLOBAL

    def reconstruct(self, codegen):
        load_ref = codegen.create_load_global(self.global_name, add=True)
        call_ref = create_instruction("CALL_FUNCTION", 0)
        return [load_ref, call_ref]
|
|
|
|
|
|
@dataclasses.dataclass
class AttrSource(Source):
    """Attribute access ``base.member``.

    A dotted ``member`` (e.g. ``"a.b.c"``) is unfolded into a chain of
    single-attribute sources, so ``AttrSource(x, "a.b.c")`` is equivalent to
    ``AttrSource(AttrSource(AttrSource(x, "a"), "b"), "c")``.
    """

    base: Source
    member: str

    def __init__(self, base, member):
        super().__init__()
        if "." in member:
            # Split off the last component and recurse on the remaining
            # prefix, so each AttrSource in the chain holds one attribute name.
            prefix, _, last = member.rpartition(".")
            self.base = AttrSource(base, prefix)
            self.member = last
        else:
            self.base = base
            self.member = member

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        """Return a Python expression that evaluates to this attribute.

        The dotted form ``base.member`` is only used when ``member`` is an
        ASCII, non-keyword identifier. Every other member name is spelled
        ``getattr(base, 'member')`` so the expression still parses and looks up
        exactly that attribute. This covers:

        - numeric names such as ``"0"``;
        - names like ``"a-b"`` or ``""``;
        - keywords such as ``"class"``;
        - non-ASCII names, which the parser NFKC-normalizes, so ``x.<name>``
          could resolve to a different attribute.
        """
        import keyword

        member = self.member
        if (
            member.isascii()
            and member.isidentifier()
            and not keyword.iskeyword(member)
        ):
            return f"{self.base.name()}.{member}"
        return f"getattr({self.base.name()}, {member!r})"
|
|
|
|
|
|
class TensorProperty(enum.Enum):
    """Which piece of tensor metadata a ``TensorPropertySource`` refers to.

    SIZE and STRIDE are read per dimension (``TensorPropertySource`` requires
    an ``idx`` for them); STORAGE_OFFSET takes no index.
    """

    # Values are assigned positionally: SIZE=0, STRIDE=1, STORAGE_OFFSET=2.
    SIZE, STRIDE, STORAGE_OFFSET = range(3)
|
|
|
|
|
|
@dataclasses.dataclass
class TensorPropertySource(Source):
    """A size entry, stride entry, or the storage offset of the tensor at ``base``.

    ``idx`` selects the dimension for SIZE / STRIDE and must be None for
    STORAGE_OFFSET; both invariants are asserted on construction. Only naming
    is supported; ``reconstruct`` is not implemented.
    """

    base: Source
    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        takes_index = self.prop is not TensorProperty.STORAGE_OFFSET
        if takes_index:
            assert self.idx is not None
        else:
            assert self.idx is None

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        prop = self.prop
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        # SIZE and STRIDE share the "<base>.<method>()[<idx>]" shape.
        if prop is TensorProperty.SIZE:
            method = "size"
        elif prop is TensorProperty.STRIDE:
            method = "stride"
        else:
            raise AssertionError(f"unhandled {self.prop}")
        return f"{self.base.name()}.{method}()[{self.idx}]"
|
|
|
|
|
|
@dataclasses.dataclass
class NegateSource(Source):
    """The negation of the value at ``base``.

    Only naming is supported; ``reconstruct`` is not implemented.
    """

    base: Source

    def name(self):
        # Spelled as an explicit ``__neg__()`` method call (not a unary minus)
        # so that the function-stripping regexes keep working on it.
        inner = self.base.name()
        return f"{inner}.__neg__()"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        raise NotImplementedError()
|
|
|
|
|
|
@dataclasses.dataclass
class GetItemSource(Source):
    """Subscript ``base[index]``.

    ``index`` is either another ``Source`` (reconstructed and named
    recursively) or a plain value (loaded as a constant and named via
    ``repr``).
    """

    base: Source
    index: Any

    def name(self):
        base_name = self.base.name()
        if isinstance(self.index, Source):
            key = self.index.name()
        else:
            key = repr(self.index)
        return f"{base_name}[{key}]"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # Extends the list returned by base.reconstruct in place, producing:
        # <base instrs>, <index instrs>, BINARY_SUBSCR
        instrs = self.base.reconstruct(codegen)
        if isinstance(self.index, Source):
            index_instrs = self.index.reconstruct(codegen)
        else:
            index_instrs = [codegen.create_load_const(self.index)]
        instrs.extend(index_instrs)
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs
|
|
|
|
|
|
@dataclasses.dataclass
class TupleIteratorGetItemSource(GetItemSource):
    """Element ``index`` of a tuple iterator, via ``utils.tuple_iterator_getitem``.

    Inherits ``base``, ``index`` and ``guard_source`` from ``GetItemSource``.
    """

    def name(self):
        call_args = f"{self.base.name()}, {self.index!r}"
        return f"___tuple_iterator_getitem({call_args})"

    def reconstruct(self, codegen):
        # load_import_from is called for its effect on codegen; its return
        # value is not used here.
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        base_instrs = self.base.reconstruct(codegen)
        call_instrs = [
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]
        return base_instrs + call_instrs
|
|
|
|
|
|
@dataclasses.dataclass
class TypeSource(Source):
    """``type(base)``: the type of the value at ``base``."""

    base: Source

    def name(self):
        inner = self.base.name()
        return f"type({inner})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # builtins.type is loaded via load_import_from (its return value is
        # unused), then base is reconstructed and passed as the single argument.
        codegen.load_import_from("builtins", "type")
        call_type = create_instruction("CALL_FUNCTION", 1)
        return self.base.reconstruct(codegen) + [call_type]
|
|
|
|
|
|
@dataclasses.dataclass
class ODictGetItemSource(Source):
    """``OrderedDict.__getitem__(base, index)``, bypassing any subclass override.

    The unbound ``collections.OrderedDict.__getitem__`` is what gets called
    during reconstruction.
    """

    base: Source
    index: Any

    def name(self):
        base_name = self.base.name()
        return f"___odict_getitem({base_name}, {self.index!r})"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # Emits: const OrderedDict.__getitem__, <base>, const <index>, call(2).
        getitem = codegen._create_load_const(collections.OrderedDict.__getitem__)
        instrs = [getitem]
        instrs.extend(self.base.reconstruct(codegen))
        instrs.append(codegen.create_load_const(self.index))
        instrs.append(create_instruction("CALL_FUNCTION", 2))
        return instrs
|
|
|
|
|
|
@dataclasses.dataclass
class NNModuleSource(Source):
    """Wrapper that marks ``inner`` as an nn.Module-flavoured source.

    Name and reconstruction are delegated unchanged to ``inner``; only the
    guard source differs, mapped through ``_GUARD_SOURCE_NN_MODULE``.
    """

    inner: Source

    def name(self):
        return self.inner.name()

    def guard_source(self):
        inner_guard = self.inner.guard_source()
        return _GUARD_SOURCE_NN_MODULE[inner_guard]

    def reconstruct(self, codegen):
        return self.inner.reconstruct(codegen)
|
|
|
|
|
|
class NotNNModuleSource(NNModuleSource):
    """Wrapper that strips the nn.Module flavour from ``inner``'s guard source.

    The guard source is mapped through ``_GUARD_SOURCE_NOT_NN_MODULE``;
    everything else (fields, name, reconstruction) is inherited.
    """

    def guard_source(self):
        inner_guard = self.inner.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[inner_guard]
|
|
|
|
|
|
@dataclasses.dataclass
class ConstantSource(Source):
    """Source for a constant referred to by the name ``source_name``.

    Its guard source is ``GuardSource.CONSTANT`` and it never produces guards:
    ``make_guard`` always raises ``NotImplementedError``.
    """

    source_name: str

    def name(self):
        return self.source_name

    def guard_source(self):
        return GuardSource.CONSTANT

    def reconstruct(self, codegen):
        # Reloaded as a global by name, with add=False.
        load_const_global = codegen.create_load_global(self.source_name, add=False)
        return [load_const_global]

    def make_guard(self, fn, is_volatile=False):
        raise NotImplementedError()
|
|
|
|
|
|
@dataclasses.dataclass
class ShapeEnvSource(Source):
    """Synthetic source for the singleton shape-env guard.

    That guard is registered for every frame. It carries nothing itself: the
    actual guard contents come from the ambient ShapeEnv.
    """

    def guard_source(self):
        return GuardSource.SHAPE_ENV

    def name(self):
        # Always the empty string.
        return ""
|