The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by the `specialize_float` config variable and kicks in when it is set to False.
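For reference, here is a minimal sketch (not part of the PR) of how you might exercise the new behavior; it assumes `specialize_float` is exposed on `torch._dynamo.config` and reuses the function body from the test referenced below:

```python
import torch
import torch._dynamo

def f(x, y):
    # Mirrors the body traced in test_unspec_float_output:
    # x is a Tensor, y is a plain Python float.
    return x + 1, y * 2

# Assumption: specialize_float lives on torch._dynamo.config and
# config.patch(...) works as a context manager.
with torch._dynamo.config.patch(specialize_float=False):
    compiled = torch.compile(f, backend="eager")
    out_tensor, out_float = compiled(torch.randn(3), 5.0)
    # out_float comes back as a Python float: the graph returns a 0-d tensor
    # and the generated epilogue calls .item() on it (see the codegen notes below).
```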
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
    l_x_ = L_x_
    l_y_ = L_y_

    # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
    add: "f32[3]" = l_x_ + 1; l_x_ = None
    item: "Sym(zf0)" = l_y_.item(); l_y_ = None
    mul: "Sym(2*zf0)" = item * 2; item = None
    scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
    return (add, scalar_tensor)
```
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable); see the user-level sketch after this list. This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy-paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We can now generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs from the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable for a float on the return stack, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special-case the output bytecode to call item() on it again (paraphrased in the second sketch after this list). The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.
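To make the guard story from the first two bullets concrete, a hedged user-level sketch (same config assumptions as above; the recompile behavior is my reading of the backed-symbol design, not something the PR asserts):

```python
import torch
import torch._dynamo

def g(x, y):
    # Branching on the Python float should install a guard on its value,
    # which is only possible because the SymNodeVariable is backed by
    # the float's source.
    if y > 0:
        return x * y
    return x - y

with torch._dynamo.config.patch(specialize_float=False):
    compiled = torch.compile(g, backend="eager")
    compiled(torch.randn(3), 2.0)   # traces the y > 0 branch, guards on y > 0
    compiled(torch.randn(3), 3.0)   # guard still holds; no recompile expected
    compiled(torch.randn(3), -1.0)  # guard fails; expect a recompile for the other branch
```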
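And to pin down the third bullet, a hand-written paraphrase of what the emitted epilogue ends up doing for the example graph above (`__compiled_fn` and the variable names are illustrative, not the actual generated identifiers):

```python
# Float input is passed to the graph as a tensor (the pass_arg_as_tensor
# branch in make_call_generated_code), and the float output comes back as
# a 0-d tensor that the epilogue unwraps with .item() right away.
graph_out = __compiled_fn(l_x_, torch.as_tensor(l_y_))
return graph_out[0], graph_out[1].item()
```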
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
437 lines · 16 KiB · Python
import collections
import dataclasses
import re
import sys
import types
from typing import Counter, Dict, List, Optional

import torch.nn
from . import utils

from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_attr,
    create_load_global,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
    ):
        self.root = root
        self.top_of_stack: Optional[VariableTracker] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: Dict[int, GraphOutputEntry] = {}
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.mutable_side_effects_from_source = False
        self.value_from_source: bool = True

    def restore_stack(self, stack_values, *, value_from_source=True):
        prior = self.mutable_side_effects_from_source
        self.mutable_side_effects_from_source = True
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.mutable_side_effects_from_source = prior
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"

    def __call__(self, value, allow_cache=True):
        """Generate code such that top-of-stack (TOS) is set to value"""
        if isinstance(value, Source):
            self.call_reconstruct(value)
            self.clear_tos()
            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value and allow_cache:
            output.append(create_dup_top())
            return

        if self.mutable_side_effects_from_source:
            # this is needed to get aliasing relationships right
            # value.mutable_local.source will get mutated to hold `value`
            # mutable_side_effects_from_source=False is used to codegen the mutation
            # mutable_side_effects_from_source=True is used to codegen a reference
            from .side_effects import MutableSideEffects

            if isinstance(value.mutable_local, MutableSideEffects):
                self(value.mutable_local.source)
                return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.source is not None and allow_cache and self.value_from_source:
            self.call_reconstruct(value.source)
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.load_import_from(utils.__name__, "to_subclass")
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), False, add=True
                )
            )
            output.extend(create_call_function(2, True))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here. Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx))
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.extend(
                [self.create_load_attr("item")] + create_call_function(0, True)
            )
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.load_import_from(utils.__name__, "to_numpy_helper")

            self.load_graph_output(graph_outputs[graph_outputs_key].index)

            if isinstance(value, NumpyNdarrayVariable):
                output.extend(create_call_function(1, True))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                output.extend(
                    [self.create_load_attr("item")] + create_call_function(0, True)
                )
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
    def add_graph_output(self, value):
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key

    def load_graph_output(self, index):
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self._create_load_const(index))
        output.append(create_instruction("BINARY_SUBSCR"))

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value, push_null):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, push_null, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> List[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("LOAD_DEREF", argval=name)
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_CLOSURE", argval=name)

    def create_store(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("STORE_DEREF", argval=name)
        assert name in self.code_options["co_varnames"]
        return create_instruction("STORE_FAST", argval=name)

    def create_load_global(self, name, push_null, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_load_global(name, push_null)

    def create_load_const(self, value) -> Instruction:
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    def _create_load_const(self, value) -> Instruction:
        return create_instruction("LOAD_CONST", argval=value)

    create_load_output = _create_load_const

    def create_load_method(self, name):
        self.tx.output.update_co_names(name)
        return create_load_method(name)

    def load_method(self, name):
        self.append_output(self.create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_load_attr(name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name on the stack num_on_stack down"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(
                [create_instruction("PUSH_NULL"), *self.rot_n(num_on_stack + 1)]
            )
        output.extend(
            [
                self.create_load_global(fn_name, False, add=True),
                *self.rot_n(num_on_stack + 1),
            ]
        )
        return output

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self._create_load_const(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]

    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self._create_load_const(lambda: None),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output
        if sys.version_info >= (3, 11) and push_null:
            output.append(create_instruction("PUSH_NULL"))
            output.extend(self.rot_n(num_on_stack + 1))
        for var in freevars:
            assert var in self.cell_and_freevars()
            output.append(create_instruction("LOAD_CLOSURE", argval=var))
        output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
        output.append(self.create_load_const(code))
        if sys.version_info < (3, 11):
            output.append(self.create_load_const(fn_name))
        output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
        output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()

    def create_load_python_module(self, mod, push_null) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, push_null, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, push_null, add=True)

    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                # Scalar graph args (e.g. unspecialized Python floats) are passed
                # into the graph as tensors, so wrap them with torch.as_tensor here.
                self.extend_output(
                    [
                        self.create_load_python_module(torch, True),
                        self.create_load_attr("as_tensor"),
                    ]
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))

    def load_import_from(self, module_name, object_name) -> None:
        self(AttrSource(self.tx.import_source(module_name), object_name))

    def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]:
        if sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]

    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)