mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
We should not allow creating a derived source (e.g. AttrSource) without a valid base source. It's more reliable to check this in the source's `__init__` or `__post_init__` than to assert that we have a valid source before passing it to an AttrSource() call. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91711 Approved by: https://github.com/voznesenskym
325 lines
8.2 KiB
Python
325 lines
8.2 KiB
Python
import collections
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Optional
|
|
|
|
from torch._guards import GuardSource, Source
|
|
|
|
from . import utils
|
|
from .bytecode_transformation import create_instruction
|
|
from .utils import rename_implicit
|
|
|
|
# Each plain input guard source paired with its nn.Module counterpart.
_NN_MODULE_GUARD_SOURCE_PAIRS = (
    (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE),
    (GuardSource.GLOBAL, GuardSource.GLOBAL_NN_MODULE),
)

# Promotion table: a plain guard source maps to its nn.Module counterpart and
# an nn.Module guard source maps to itself.
_GUARD_SOURCE_NN_MODULE = {
    plain: nn_module for plain, nn_module in _NN_MODULE_GUARD_SOURCE_PAIRS
}
_GUARD_SOURCE_NN_MODULE.update(
    {nn_module: nn_module for _, nn_module in _NN_MODULE_GUARD_SOURCE_PAIRS}
)

# Demotion table: an nn.Module guard source maps back to its plain counterpart
# and a plain guard source maps to itself.
_GUARD_SOURCE_NOT_NN_MODULE = {
    plain: plain for plain, _ in _NN_MODULE_GUARD_SOURCE_PAIRS
}
_GUARD_SOURCE_NOT_NN_MODULE.update(
    {nn_module: plain for plain, nn_module in _NN_MODULE_GUARD_SOURCE_PAIRS}
)
|
|
|
|
|
|
def is_constant_source(source):
    """Return True if ``source`` denotes a constant.

    A ``ConstantSource`` always qualifies. Any other source qualifies when its
    ``guard_source()`` is ``GuardSource.CONSTANT``. A source whose
    ``guard_source()`` raises ``NotImplementedError`` is treated as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        return bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        return False
|
|
|
|
|
|
def is_input_source(source):
    """Return True if ``source``'s guard source is a local or a global.

    The nn.Module-flavoured variants of local and global also count.
    """
    input_guard_sources = (
        GuardSource.LOCAL,
        GuardSource.GLOBAL,
        GuardSource.LOCAL_NN_MODULE,
        GuardSource.GLOBAL_NN_MODULE,
    )
    return source.guard_source() in input_guard_sources
|
|
|
|
|
|
@dataclasses.dataclass
class LocalSource(Source):
    """Source naming a local variable, ``local_name``."""

    local_name: str

    def reconstruct(self, codegen):
        # A single load of the named local.
        load_local = codegen.create_load(self.local_name)
        return [load_local]

    def guard_source(self):
        return GuardSource.LOCAL

    def name(self):
        # The guard-facing name is the local name passed through rename_implicit.
        return rename_implicit(self.local_name)
|
|
|
|
|
|
@dataclasses.dataclass
class RandomValueSource(Source):
    """Source for entry ``random_call_index`` of the value held in
    ``codegen.tx.output.random_values_var``."""

    random_call_index: int

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def reconstruct(self, codegen):
        # Emits: load random_values_var; load const index; BINARY_SUBSCR.
        load_values = codegen.create_load(codegen.tx.output.random_values_var)
        load_index = codegen.create_load_const(self.random_call_index)
        subscript = create_instruction("BINARY_SUBSCR")
        return [load_values, load_index, subscript]

    def name(self):
        return rename_implicit(f"random_value_{self.random_call_index}")
|
|
|
|
|
|
@dataclasses.dataclass
class GlobalSource(Source):
    """Source naming a global variable, ``global_name``."""

    global_name: str

    def reconstruct(self, codegen):
        load_global = codegen.create_load_global(self.global_name, add=True)
        return [load_global]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        # Globals are referred to verbatim in guard expressions.
        return self.global_name
|
|
|
|
|
|
@dataclasses.dataclass
class GlobalWeakRefSource(Source):
    """Source for the result of calling the global ``global_name`` with no
    arguments (presumably a weak reference being dereferenced; not checked
    here)."""

    global_name: str

    def reconstruct(self, codegen):
        # Load the global, then call it with zero arguments.
        load_ref = codegen.create_load_global(self.global_name, add=True)
        call_ref = create_instruction("CALL_FUNCTION", 0)
        return [load_ref, call_ref]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"{self.global_name}()"
|
|
|
|
|
|
@dataclasses.dataclass
class AttrSource(Source):
    """Source for attribute ``member`` of the value described by ``base``.

    A dotted ``member`` such as ``"a.b.c"`` is expanded into a chain of
    single-attribute sources, so ``AttrSource(x, "a.b.c")`` is equivalent to
    ``AttrSource(AttrSource(AttrSource(x, "a"), "b"), "c")``.

    ``base`` must not be None. This is asserted at construction time, at every
    level of the chain, so an invalid source fails here rather than when it is
    later used.
    """

    base: Source
    member: str

    def __init__(self, base, member):
        super().__init__()
        if "." in member:
            # Peel off the last component; the remaining prefix recursively
            # becomes our base.
            prefix, member = member.rsplit(".", 1)
            base = AttrSource(base, prefix)
        self.base = base
        self.member = member
        assert self.base is not None

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        # Dotted access only parses when the member is a valid identifier.
        # Anything else (numeric names like "0", or names like "foo-bar")
        # must be spelled with getattr() to produce a valid expression.
        if not self.member.isidentifier():
            return f"getattr({self.base.name()}, {self.member!r})"
        return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
class TensorProperty(enum.Enum):
    """Tensor metadata properties a TensorPropertySource can refer to."""

    SIZE = 0  # rendered as .size()[idx]
    STRIDE = 1  # rendered as .stride()[idx]
    STORAGE_OFFSET = 2  # rendered as .storage_offset(); takes no index
|
|
|
|
|
|
@dataclasses.dataclass
class TensorPropertySource(Source):
    """Source for a metadata property (size, stride or storage offset) of the
    tensor described by ``base``.

    ``idx`` selects the dimension for SIZE and STRIDE and must be None for
    STORAGE_OFFSET. This source cannot be reconstructed.
    """

    base: Source
    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        # Only the per-dimension properties carry an index.
        needs_idx = self.prop is not TensorProperty.STORAGE_OFFSET
        assert (self.idx is not None) == needs_idx

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        raise AssertionError(f"unhandled {self.prop}")
|
|
|
|
|
|
@dataclasses.dataclass
class NegateSource(Source):
    """Source for the negation of the value described by ``base``.

    Cannot be reconstructed; it only supplies a guard name and guard source.
    """

    base: Source

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        base_name = self.base.name()
        # NB: use method call so that function stripping regexes work
        return f"{base_name}.__neg__()"
|
|
|
|
|
|
@dataclasses.dataclass
class GetItemSource(Source):
    """Source for ``base[index]``.

    ``index`` is either a constant or another Source. When it is a Source,
    the index expression is itself reconstructed and named from that source.
    """

    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        # Build on the list returned by the base: base, then index, then subscript.
        instrs = self.base.reconstruct(codegen)
        index_is_source = isinstance(self.index, Source)
        instrs.extend(
            self.index.reconstruct(codegen)
            if index_is_source
            else [codegen.create_load_const(self.index)]
        )
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        base_name = self.base.name()
        if isinstance(self.index, Source):
            return f"{base_name}[{self.index.name()}]"
        return f"{base_name}[{self.index!r}]"
|
|
|
|
|
|
@dataclasses.dataclass
class TupleIteratorGetItemSource(GetItemSource):
    """Source for element ``index`` of (presumably) a tuple iterator described
    by ``base``, read via ``utils.tuple_iterator_getitem``. In guard
    expressions the helper appears as ``___tuple_iterator_getitem``.
    """

    def reconstruct(self, codegen):
        # Must run before the base is reconstructed; presumably this emits the
        # load of the helper that CALL_FUNCTION 2 then calls with (base, index).
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        call_args = self.base.reconstruct(codegen)
        return call_args + [
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass
class TypeSource(Source):
    """Source for ``type(...)`` applied to the value described by ``base``."""

    base: Source

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        # Must run before the base is reconstructed; presumably this emits the
        # load of builtins.type, which CALL_FUNCTION 1 then applies to the base.
        codegen.load_import_from("builtins", "type")
        operand = self.base.reconstruct(codegen)
        return operand + [create_instruction("CALL_FUNCTION", 1)]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"type({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass
class ODictGetItemSource(Source):
    """Source for ``OrderedDict.__getitem__(base, index)``.

    The lookup calls OrderedDict's own ``__getitem__`` directly instead of
    emitting ``base[index]``.
    """

    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        # NOTE(review): the private _create_load_const is used for the method
        # descriptor, presumably to bypass checks done by the public
        # create_load_const; confirm against codegen before changing.
        load_getitem = codegen._create_load_const(collections.OrderedDict.__getitem__)
        base_instrs = self.base.reconstruct(codegen)
        load_index = codegen.create_load_const(self.index)
        call_getitem = create_instruction("CALL_FUNCTION", 2)
        return [load_getitem] + base_instrs + [load_index, call_getitem]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"___odict_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass
class NNModuleSource(Source):
    """Wrapper that relabels ``inner`` with the nn.Module flavour of its guard
    source, while reconstructing and naming exactly like ``inner``.

    ``inner`` must not be None. As with the other derived sources in this
    module, that is asserted at construction time instead of failing later
    inside ``guard_source``/``reconstruct``/``name``.
    """

    inner: Source

    def __post_init__(self):
        assert self.inner is not None

    def reconstruct(self, codegen):
        return self.inner.reconstruct(codegen)

    def guard_source(self):
        # KeyError if inner's guard source has no entry in the promotion table.
        return _GUARD_SOURCE_NN_MODULE[self.inner.guard_source()]

    def name(self):
        return self.inner.name()
|
|
|
|
|
|
@dataclasses.dataclass
class NotNNModuleSource(NNModuleSource):
    """Counterpart of NNModuleSource that maps ``inner``'s guard source to its
    plain (non-nn.Module) flavour instead."""

    def guard_source(self):
        inner_guard_source = self.inner.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[inner_guard_source]
|
|
|
|
|
|
@dataclasses.dataclass
class ConstantSource(Source):
    """Source named ``source_name`` whose guard source is CONSTANT.

    It is reconstructed as a global load with ``add=False``. It cannot
    produce guards: ``make_guard`` always raises NotImplementedError.
    """

    source_name: str

    def reconstruct(self, codegen):
        load_constant = codegen.create_load_global(self.source_name, add=False)
        return [load_constant]

    def guard_source(self):
        return GuardSource.CONSTANT

    def name(self):
        return self.source_name

    def make_guard(self, fn, is_volatile=False):
        raise NotImplementedError()
|
|
|
|
|
|
@dataclasses.dataclass
class ShapeEnvSource(Source):
    """Synthetic source associated with the singleton shape env guard that is
    always registered for all frames.

    The actual guard contents come from the ambient ShapeEnv. This source's
    name is the empty string, and its guard source is SHAPE_ENV.
    """

    def guard_source(self):
        return GuardSource.SHAPE_ENV

    def name(self):
        return ""
|