pytorch/torch/_dynamo/source.py
Edward Z. Yang 2ba102f689 Implement native support for float inputs in Dynamo and ShapeEnv (#125325)
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is enabled by setting the `specialize_float` config variable to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
 def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
     l_x_ = L_x_
     l_y_ = L_y_

     # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
     add: "f32[3]" = l_x_ + 1;  l_x_ = None
     item: "Sym(zf0)" = l_y_.item();  l_y_ = None
     mul: "Sym(2*zf0)" = item * 2;  item = None
     scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
     return (add, scalar_tensor)
```

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we use for the rest of the tracing. Importantly, this SymNodeVariable is backed by the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy-paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special-case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-05-14 04:10:01 +00:00

631 lines
19 KiB
Python

import collections
import dataclasses
import enum
from typing import Any, Optional, Union
from torch._guards import ChainedSource, GuardSource, Source
from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally
_GUARD_SOURCE_NN_MODULE = {
GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}
_GUARD_SOURCE_FSDP_MODULE = {
GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
_GUARD_SOURCE_NOT_NN_MODULE = {
GuardSource.LOCAL: GuardSource.LOCAL,
GuardSource.GLOBAL: GuardSource.GLOBAL,
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
}
def is_constant_source(source):
    """Report whether *source* refers to a compile-time constant.

    True for any ConstantSource instance, and for any other source whose
    ``guard_source()`` equals GuardSource.CONSTANT. A source whose
    ``guard_source()`` raises NotImplementedError is treated as non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        return bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        return False
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Emit bytecode that pushes ``source.base`` followed by its index.

    The caller appends the instruction(s) that perform the actual lookup.

    Args:
        source: the getitem-style source being rebuilt.
        codegen: the bytecode emitter receiving the instructions.
        index_is_slice: when True (GetItemSource only), the stored index is
            the reduced form of a slice and is rebuilt via ``unpack_slice()``
            before being loaded as a constant.
    """
    source.base.reconstruct(codegen)
    index = source.index

    # A Source index (e.g. a dict-key source) knows how to rebuild itself.
    if isinstance(index, Source):
        index.reconstruct(codegen)
        return

    if index_is_slice:
        assert isinstance(source, GetItemSource)
        constant = source.unpack_slice()
    else:
        constant = index
    codegen.append_output(codegen.create_load_const(constant))
@dataclasses.dataclass(frozen=True)
class LocalSource(Source):
local_name: str
cell_or_freevar: bool = False
def reconstruct(self, codegen):
codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self):
return GuardSource.LOCAL
def name(self):
return f"L[{repr(self.local_name)}]"
@dataclasses.dataclass(frozen=True)
class SyntheticLocalSource(Source):
local_name: str
def reconstruct(self, codegen):
codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self):
return GuardSource.SYNTHETIC_LOCAL
def name(self):
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
@dataclasses.dataclass(frozen=True)
class RandomValueSource(Source):
random_call_index: int
def guard_source(self):
return GuardSource.RANDOM_VALUE
def reconstruct(self, codegen):
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
codegen.append_output(codegen.create_load_const(self.random_call_index))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def name(self):
return f"random_value_{self.random_call_index}"
@dataclasses.dataclass(frozen=True)
class GlobalSource(Source):
global_name: str
def reconstruct(self, codegen):
codegen.append_output(
codegen.create_load_global(self.global_name, False, add=True)
)
def guard_source(self):
return GuardSource.GLOBAL
def name(self):
return f"G[{repr(self.global_name)}]"
@dataclasses.dataclass(frozen=True)
class GlobalWeakRefSource(Source):
global_name: str
def reconstruct(self, codegen):
codegen.append_output(
codegen.create_load_global(self.global_name, True, add=True)
)
codegen.extend_output(create_call_function(0, False))
def guard_source(self):
return GuardSource.GLOBAL
def name(self):
return f"G[{repr(self.global_name)}]()"
@dataclasses.dataclass(frozen=True)
class AttrSource(ChainedSource):
member: str
def __post_init__(self):
assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member:
member_parts = self.member.split(".")
object.__setattr__(
self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
)
object.__setattr__(self, "member", member_parts[-1])
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
return self.base.guard_source()
def name(self):
if not self.member.isidentifier():
return f"getattr({self.base.name()}, {self.member!r})"
return f"{self.base.name()}.{self.member}"
# Represents tensor.grad source. It could be represented by AttrSource as well.
# But, we could access grad field on tensor directly in C++ without going
# through the Python bytecodes. Therefore, we use a separate source for grad
# field.
@dataclasses.dataclass(frozen=True)
class GradSource(ChainedSource):
member: str = "grad"
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.{self.member}"
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """An AttrSource whose guard source is promoted to its NN-module flavour.

    Per its name, it is used for parameter/buffer attributes. Naming and
    reconstruction are inherited from AttrSource.
    """

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
# This source is intended to be used in places where a source is needed but it is expected
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
# source. Guarding on this source is an error.
#
# Example: During subclass view fake-ification, any close-over ViewFunc state should be
# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
# is useful for symbols utilized in the middle of the view chain that are not expected to be
# present within the final view shape metadata.
@dataclasses.dataclass(frozen=True)
class EphemeralSource(Source):
desc: Optional[str] = None
def guard_source(self):
return GuardSource.EPHEMERAL
def name(self):
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
def make_guard(self):
raise NotImplementedError
def is_ephemeral(self):
return True
class TensorProperty(enum.Enum):
    """Tensor metadata that can be read through a TensorPropertySource."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the name of the tensor method that reads this property."""
        # Built inside the method: a dict assigned in an Enum class body would
        # itself become an enum member.
        by_member = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return by_member.get(self)
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """One piece of tensor metadata read off ``base``.

    SIZE and STRIDE require a dimension index ``idx``. STORAGE_OFFSET takes
    none, so ``idx`` must be None for it.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        wants_index = self.prop is not TensorProperty.STORAGE_OFFSET
        if wants_index:
            assert self.idx is not None
        else:
            assert self.idx is None

    def reconstruct(self, codegen):
        # Emits base.<method>(idx), or base.<method>() when there is no index.
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_attr(self.prop.method_name()))
        arg_count = 0
        if self.idx is not None:
            codegen.append_output(codegen.create_load_const(self.idx))
            arg_count = 1
        codegen.extend_output(create_call_function(arg_count, True))

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        prop = self.prop
        if prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")
@dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
raise NotImplementedError
def guard_source(self):
return self.base.guard_source()
def name(self):
# NB: use method call so that function stripping regexes work
return f"{self.base.name()}.__neg__()"
@dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}._type().qualified_name()"
@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
idx_key: Union[int, str]
is_kw: bool = False
field: str = dataclasses.field(init=False, repr=False, compare=False)
_name: str = dataclasses.field(init=False, repr=False, compare=False)
def __post_init__(self):
assert (
self.base
), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
if self.is_kw:
assert isinstance(self.idx_key, str)
object.__setattr__(self, "field", "__kwdefaults__")
object.__setattr__(
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
)
else:
assert isinstance(self.idx_key, int)
object.__setattr__(self, "field", "__defaults__")
object.__setattr__(
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
)
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
codegen.extend_output(codegen.create_load_attrs(self.field))
codegen.append_output(codegen.create_load_const(self.idx_key))
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self):
return self.base.guard_source()
def name(self):
return self._name
@dataclasses.dataclass(frozen=True)
class GetItemSource(ChainedSource):
index: Any
index_is_slice: bool = False
def __post_init__(self):
assert self.base is not None
if isinstance(self.index, slice):
# store the hashable version of the slice so the whole GetItemSource is hashable
super().__setattr__("index", self.index.__reduce__())
super().__setattr__("index_is_slice", True)
def reconstruct(self, codegen):
reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice)
codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self):
return self.base.guard_source()
def unpack_slice(self):
assert self.index_is_slice
slice_class, slice_args = self.index
return slice_class(*slice_args)
def name(self):
# Index can be of following types
# 1) ConstDictKeySource
# 2) enum.Enum
# 3) index is a slice - example 1:4
# 4) index is a constant - example string, integer
if isinstance(self.index, Source):
if not isinstance(self.index, ConstDictKeySource):
raise ValueError(
"GetItemSource index must be a constant, enum or ConstDictKeySource"
)
return f"{self.base.name()}[{self.index.name()}]"
elif self.index_is_slice:
return f"{self.base.name()}[{self.unpack_slice()!r}]"
elif isinstance(self.index, enum.Enum):
return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
else:
return f"{self.base.name()}[{self.index!r}]"
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """The key at position ``index`` of dict ``base``: ``list(base.keys())[index]``."""

    def name(self):
        # The list creation will be CSE'd by PyExprCSEPass
        return f"list({self.base.name()}.keys())[{self.index!r}]"

    def reconstruct(self, codegen):
        # Emits utils.dict_keys_getitem(base, index).
        codegen.load_import_from(utils.__name__, "dict_keys_getitem")
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, True))

    def is_dict_key(self):
        return True
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Element ``index`` of the tuple iterator ``base``.

    Both naming and reconstruction go through the tuple_iterator_getitem
    helper from .utils rather than a plain subscript.
    """

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen):
        # Emits utils.tuple_iterator_getitem(base, index).
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, True))
@dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
codegen.load_import_from("builtins", "type")
self.base.reconstruct(codegen)
codegen.extend_output(create_call_function(1, True))
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"type({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class ODictGetItemSource(ChainedSource):
index: Any
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
codegen.append_output(
codegen._create_load_const(collections.OrderedDict.__getitem__)
)
reconstruct_getitem(self, codegen, index_is_slice=False)
codegen.extend_output(create_call_function(2, True))
def guard_source(self):
return self.base.guard_source()
def name(self):
if isinstance(self.index, type):
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
return f"___odict_getitem({self.base.name()}, {rep})"
elif isinstance(self.index, Source):
return f"___odict_getitem({self.base.name()}, {self.index.name()})"
else:
return f"___odict_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass(frozen=True)
class OptimizerSource(ChainedSource):
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return self.base.name()
@dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource):
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
def name(self):
return self.base.name()
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """NNModuleSource variant that maps the guard source through _GUARD_SOURCE_NOT_NN_MODULE.

    Naming and reconstruction are inherited from NNModuleSource.
    """

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """NNModuleSource variant that maps the guard source through _GUARD_SOURCE_FSDP_MODULE.

    Naming and reconstruction are inherited from NNModuleSource.
    """

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
@dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.GLOBAL
@dataclasses.dataclass(frozen=True)
class ConstantSource(Source):
source_name: str
def reconstruct(self, codegen):
codegen.append_output(
codegen.create_load_global(self.source_name, False, add=False)
)
def guard_source(self):
return GuardSource.CONSTANT
def name(self):
return self.source_name
def make_guard(self, fn):
raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class NumpyTensorSource(ChainedSource):
def name(self) -> str:
return f"___from_numpy({self.base.name()})"
def guard_source(self):
return self.base.guard_source()
def reconstruct(self, codegen):
codegen.load_import_from("torch", "as_tensor")
self.base.reconstruct(codegen)
codegen.extend_output(create_call_function(1, True))
# NB: We don't expect you to actually ever generate guards against this
# source, it is ephemeral
@dataclasses.dataclass(frozen=True)
class FloatTensorSource(ChainedSource):
def name(self) -> str:
return f"___as_tensor({self.base.name()})"
def guard_source(self):
return self.base.guard_source()
# This is a synthetic source that is associated with the singleton
# shape env guard we always register for all frames. We get the actual
# guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source):
def name(self):
return ""
def guard_source(self):
return GuardSource.BACKWARD_STATE
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True if the root of *source*'s chain is a LocalSource.

    Follows ``.base`` links through any ChainedSource wrappers down to the
    root. When *allow_cell_or_freevar* is False, a root LocalSource flagged
    ``cell_or_freevar`` does not count.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    if root.cell_or_freevar and not allow_cell_or_freevar:
        return False
    return True
def is_from_flatten_script_object_source(source: Source):
    """Return True if any link in *source*'s chain is a FlattenScriptObjectSource."""
    node = source
    while True:
        if isinstance(node, FlattenScriptObjectSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
def is_from_optimizer_source(source: Source):
    """Return True if any link in *source*'s chain is an OptimizerSource."""
    node = source
    while True:
        if isinstance(node, OptimizerSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
# TODO: the is_from_* predicates share one shape ("does any link in the chain
# match?") and could be folded into a single generic helper.
def is_from_defaults(source: Source):
    """Return True if any link in *source*'s chain is a DefaultsSource."""
    node = source
    while True:
        if isinstance(node, DefaultsSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base