import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    OrderedDict,
    Set,
    Tuple,
    Union,
)

import sympy
from typing_extensions import Protocol

import torch.nn
from torch import fx
from torch._guards import Guard
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config, logging as torchdynamo_logging, variables
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import ConstantSource, LocalSource, Source
from .utils import (
    assert_no_fake_params_or_buffers,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    format_graph_tabular,
    same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    DynamicShapeVariable,
    TensorVariable,
    UnspecializedNumpyVariable,
    UnspecializedPythonVariable,
)

log = logging.getLogger(__name__)


# TODO: I think this accepts int arguments too
class CompiledFn(Protocol):
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...


CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]


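# Example (illustrative, not part of this module): any function matching the
# CompilerFn signature can serve as a backend. A hypothetical pass-through
# backend that prints the captured graph and returns the unmodified forward:
#
#     def debug_backend(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
#         gm.graph.print_tabular()  # inspect what dynamo captured
#         return gm.forward  # gm.forward satisfies the CompiledFn protocol
#
#     opt_fn = torch._dynamo.optimize(debug_backend)(fn)

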
class OutputGraphState(NamedTuple):
    graphargs: List[GraphArg]
    guards: Set[Guard]
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int
    name_to_input: OrderedDict[str, Optional[fx.Proxy]]

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        for k in self._fields:
            if k == "side_effects":
                r = self.side_effects.diff(other.side_effects)
                if r is not None:
                    return r
                continue

            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{prefix}{k} mismatch: {sv} != {ov}"
        return None


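# Example (illustrative): diff() reports the first field that changed between
# two checkpoints, e.g. for states s0 and s1 taken around speculative tracing:
#
#     s0 = output_graph.copy_graphstate()
#     ...  # trace more bytecode
#     s1 = output_graph.copy_graphstate()
#     s1.diff(s0)  # e.g. "timestamp mismatch: 1 != 0", or None if identical

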
@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]


def _get_gen_rand_values_fn(random_calls):
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values


class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super(FakeRootModule, self).__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self):
        return "FakeRootModule(...)"


class WrapperBackend:
    def __init__(self, backend: CompilerFn, original_example_inputs):
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")

        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()


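# Example (illustrative): setting config.verify_correctness causes
# call_user_compiler below to wrap the user backend in WrapperBackend, which
# runs both the original gm.forward and the compiled candidate on cloned
# example inputs and compares them with `same`:
#
#     torch._dynamo.config.verify_correctness = True
#     opt_fn = torch._dynamo.optimize(my_backend)(fn)  # my_backend, fn: user-defined
#     opt_fn(*example_inputs)  # RuntimeError if the backend's results diverge

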
class OutputGraph(fx.Tracer):
    """
    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
    generated fx.Graph.
    """

    def __init__(
        self,
        f_globals: Dict[str, Any],
        code_options: Dict[str, Any],
        compiler_fn: CompilerFn,
        root_tx,
    ):
        super(OutputGraph, self).__init__()

        self.graph = torch.fx.Graph()
        self.graphargs: List[GraphArg] = []
        # Although we prune unused graphargs before sending graphs to
        # compilers, we may have legitimately triggered shape guards
        # on "unused" inputs that we must keep track of.  So after
        # remove_unused_graphargs is called, orig_graphargs and
        # graphargs no longer alias; orig_graphargs is the original
        # graphargs, and graphargs is the pruned list.  Guard creation
        # should use original graphargs.
        self.orig_graphargs: List[GraphArg] = self.graphargs
        self.guards: Set[Guard] = set()
        self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
        self.side_effects = SideEffects()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # Not checkpointed
        self.compiler_fn: CompilerFn = compiler_fn
        self.root_globals = f_globals
        self.root_tx = root_tx
        from torch._dynamo.symbolic_convert import InstructionTranslatorBase

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.random_values_var = None
        self.initial_random_state = ()
        self.unspec_variable_map: Dict[
            str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
        ] = {}
        self.shape_env = ShapeEnv() if config.dynamic_shapes else None
        self.intermediary_symbols: Dict[sympy.Expr, None] = {}

        # Enables creating unique node names by tracking
        # all current placeholder node names
        self.name_to_input: OrderedDict[
            str, Optional[fx.Proxy]
        ] = collections.OrderedDict()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.root_tx.fake_mode

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def copy_graphstate(self) -> OutputGraphState:
        """Create a checkpoint of the current state by copying everything"""
        assert self.nn_modules is not None
        state = OutputGraphState(
            list(self.graphargs),
            set(self.guards),
            dict(self.nn_modules),
            self.side_effects.clone(),
            self.timestamp,
            self.name_to_input.copy(),
        )
        self.timestamp += 1
        return state

    def restore_graphstate(self, state: OutputGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            self.graphargs,
            self.guards,
            self.nn_modules,
            self.side_effects,
            self.timestamp,
            self.name_to_input,
        ) = state
        # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
        removed_nodes = 0
        for node in reversed(list(self.graph.nodes)):
            if node.meta["creation_timestamp"] > self.timestamp:
                # Erasing the node alone does not remove its meta information,
                # so remove the example-value tensor explicitly
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)
                removed_nodes += 1
        log.debug(f"restore_graphstate: removed {removed_nodes} nodes")

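    # Example (illustrative): the checkpoint/restore pair lets the symbolic
    # interpreter speculate and roll back. A minimal sketch of the pattern:
    #
    #     state = self.copy_graphstate()  # bumps self.timestamp
    #     try:
    #         ...  # speculatively trace; new nodes get a later creation_timestamp
    #     except Exception:  # whatever failure triggered the rollback
    #         self.restore_graphstate(state)  # erases nodes newer than the checkpoint
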
    def count_calls(self):
        return count_calls(self.graph)

    def get_submodule(self, keys):
        assert keys
        obj = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj

    def create_graph_input(self, name, type_expr=None):
        # unique
        if name in self.name_to_input:
            for i in itertools.count():
                if f"{name}_{i}" not in self.name_to_input:
                    name = f"{name}_{i}"
                    break

        if self.name_to_input:
            prev_name = next(reversed(self.name_to_input))
            ctx = self.graph.inserting_after(self.name_to_input[prev_name])
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            self.name_to_input[name] = proxy.node
            return proxy

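    # Example (illustrative): placeholder names are de-duplicated through
    # name_to_input, so repeated names receive numeric suffixes:
    #
    #     self.create_graph_input("x")  # placeholder named "x"
    #     self.create_graph_input("x")  # placeholder named "x_0"
    #     self.create_graph_input("x")  # placeholder named "x_1"
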
    def new_var(self, name="tmp"):
        existing = set(self.code_options["co_varnames"])
        for i in itertools.count():
            var = f"___{name}_{i}"
            if var not in existing:
                self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
                    var,
                )
                return var

    def update_co_names(self, name):
        """Ensure self.code_options.co_names contains name"""
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] = tuple(self.code_options["co_names"]) + (
                name,
            )

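    # Example (illustrative): new_var reserves a fresh local-variable slot in
    # the code object being rewritten, while update_co_names does the same for
    # the co_names tuple:
    #
    #     var = self.new_var("graph_out")  # e.g. "___graph_out_0", appended to co_varnames
    #     self.update_co_names("random")   # ensures "random" can be referenced by name
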
    def register_attr_or_module(
        self, target: Union[torch.nn.Module, torch.Tensor, Any], *names, **options
    ):
        if is_dynamic_nn_module(target):
            return variables.UnspecializedNNModuleVariable(target, **options)

        options = dict(options)
        options["guards"] = set(options.get("guards", []))
        source: Source = options.get("source", None)
        if isinstance(target, torch.Tensor):
            if source:
                options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))

            def wrap_name(module_key):
                return wrap_fx_proxy(
                    self,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    example_value=target,
                    **options,
                )

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)
            options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key):
                return NNModuleVariable(type(target), module_key, **options)

        elif isinstance(target, (torch.SymInt, torch.SymFloat)):
            # HACKY CODE REGION BEGIN
            # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
            # This ultimately gets written to self.nn_modules, which is unfortunate
            # Attrs that are tensors and symints and such need to be migrated
            # to have their own storage; alas, this is how it is for now
            self.intermediary_symbols.update({target.get_pyobj().expr: None})

            def wrap_name(module_key):
                return DynamicShapeVariable.create(
                    self,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    dyn_shape=target,
                    **options,
                )

            # HACKY CODE REGION END
        else:

            def wrap_name(module_key):
                self.output.update_co_names(module_key)
                self.root_globals[module_key] = target
                return VariableBuilder(self, ConstantSource(source_name=module_key))(
                    target
                )

        assert self.nn_modules is not None
        for k, v in self.nn_modules.items():
            if v is target:
                # it already exists
                return wrap_name(k)

        # create a new unique name
        name = "_".join(map(str, names))
        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
        # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
        name = re.sub(r"[^a-zA-Z0-9]", "_", name)

        if not name or not name[0].isalpha():
            name = "sub" + name
        base = name
        for i in itertools.count():
            if name not in self.nn_modules:
                self.nn_modules[name] = target
                return wrap_name(name)
            name = f"{base}_{i}"

        raise AssertionError("unreachable")

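    # Example (illustrative): the two re.sub passes above turn a qualified
    # attribute path into a valid identifier for the module dict:
    #
    #     "abc.xyz[123].qkv" -> "abc.xyz_123.qkv" -> "abc_xyz_123_qkv"
    #
    # A name that does not start with a letter, e.g. "0.linear", becomes
    # "sub0_linear" via the "sub" prefix fallback.
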
    def compile_subgraph(
        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
    ):
        """
        Generate a subgraph to continue execution on user code.
        Automatically restore live variables.
        """
        from .eval_frame import disable

        self.partial_convert = partial_convert
        self.compile_subgraph_reason = reason

        log.debug(f"COMPILING GRAPH due to {reason}")

        if not all(block.can_restore() for block in tx.block_stack):
            unimplemented("compile_subgraph with block_depth != 0")

        for block in reversed(tx.block_stack):
            block.exit(tx)

        tx.prune_dead_locals()
        stack_values = list(tx.stack)
        assert self.nn_modules is not None
        root = FakeRootModule(self.nn_modules)

        # Add all the local vars to the "stack" so we restore them at the end
        restore_vars = []
        val_to_names: OrderedDict[
            VariableTracker, List[str]
        ] = collections.OrderedDict()
        if stack_values:
            val_to_names[stack_values[-1]] = list()
        for k, v in tx.symbolic_locals.items():
            if isinstance(v.source, LocalSource) and v.source.name() == k:
                continue  # no need to restore initial state
            if v not in val_to_names:
                val_to_names[v] = list()
            val_to_names[v].append(k)
        for v in val_to_names.keys():
            restore_vars.extend(val_to_names[v])
            stack_values.extend([v] * len(val_to_names[v]))

        # to handle random calls
        if len(tx.random_calls) > 0:
            random_calls_instructions = []
            self.random_values_var = self.new_var("random_values")
            rand_fn_name = unique_id("__gen_rand_values")
            rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
            self.install_global(rand_fn_name, rand_fn)
            codegen = PyCodegen(tx, root)
            random_calls_instructions.extend(
                [
                    codegen.create_load_global("random", add=True),
                    codegen.create_load_attr("setstate"),
                    codegen.create_load_const(tx.output.initial_random_state),
                    create_instruction("CALL_FUNCTION", 1),
                ]
            )
            random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
            random_calls_instructions.extend(
                [
                    create_instruction("CALL_FUNCTION", 0),
                    codegen.create_store(tx.output.random_values_var),
                ]
            )
            self.add_output_instructions(random_calls_instructions)

        if (
            stack_values
            and all(
                not isinstance(
                    v, (UnspecializedNumpyVariable, UnspecializedPythonVariable)
                )
                for v in stack_values
            )
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
            and self.side_effects.is_empty()
        ):
            # optimization to generate better code in a common case
            self.add_output_instructions(
                self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
                + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
            )
        else:
            graph_output_var = self.new_var("graph_out")
            pass1 = PyCodegen(tx, root, graph_output_var)
            self.side_effects.codegen_save_tempvars(pass1)
            pass1.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass1)

            # one more time now that we have established tempvars
            pass2 = PyCodegen(
                tx,
                root,
                graph_output_var,
                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
            )
            self.side_effects.codegen_save_tempvars(pass2)
            pass2.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass2)

            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
                )

                if len(pass2.graph_outputs) != 0:
                    output.append(pass2.create_store(graph_output_var))
                else:
                    output.append(create_instruction("POP_TOP"))
            self.add_output_instructions(output + pass2.get_instructions())

        # restore all the live local vars
        self.add_output_instructions(
            [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
        )

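    # Example (illustrative): a graph break lands here. For a hypothetical
    # function with an untraceable call in the middle:
    #
    #     def fn(x):
    #         y = x.sin()
    #         print(y)        # graph break: compile_subgraph() flushes `x.sin()`
    #         return y.cos()  # traced into a separate, second graph
    #
    # the rewritten bytecode calls the first compiled graph, restores `y`,
    # runs print(), then resumes in a continuation that hits the second graph.
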
    def compile_and_call_fx_graph(self, tx, rv, root):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.
        """
        from .eval_frame import disable

        assert isinstance(rv, list)
        assert isinstance(root, FakeRootModule)
        for output in rv:
            self.guards.update(output.guards)

        self.create_node(
            "output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
        )
        self.remove_unused_graphargs()
        ncalls = count_calls(self.graph)
        counters["stats"]["calls_captured"] += ncalls
        counters["stats"]["fusions_possible"] += ncalls - 1

        if config.dynamic_propagation:
            # free a bit of memory
            for node in self.graph.nodes:
                if "example_value" in node.meta:
                    del node.meta["example_value"]
            self.real_value_cache.clear()

        gm = fx.GraphModule(root, self.graph)
        gm.recompile()
        gm.compile_subgraph_reason = self.compile_subgraph_reason
        name = unique_id("__compiled_fn")

        assert_no_fake_params_or_buffers(gm)
        compiled_fn = self.call_user_compiler(gm)
        compiled_fn = disable(compiled_fn)

        counters["stats"]["unique_graphs"] += 1
        self.install_global(name, compiled_fn)

        try:
            # the call to tabulate can cause a lot of memory to be allocated
            if config.log_level <= logging.INFO and config.output_code:
                graph_str = (
                    gm.print_readable()
                    if config.output_graph_code
                    else format_graph_tabular(gm.graph)
                )
                log.log(
                    logging.INFO,
                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
                )
        except ImportError:
            log.warning(
                "Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
                "which could not be found on this machine. Run `pip "
                "install tabulate` to install the library."
            )

        cg = PyCodegen(tx)
        cg.make_call_generated_code(name)
        return cg.get_instructions()

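    # Example (illustrative): install_global publishes the compiled function
    # under a unique name such as "__compiled_fn_0" in the frame's globals,
    # and the instructions returned here are roughly the bytecode for:
    #
    #     __compiled_fn_0(x, y, ...)  # leaves the result tuple on the stack
    #
    # (names are representative; unique_id() picks the actual suffix).
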
    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            # WrapperBackend needs real inputs, for now, to verify correctness
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn, self.example_inputs())

            # NOTE: [Real Tensors in Accuracy Evaluation]
            #
            # Today, tensors are passed to backends as fake at compile time. See the .fake_example_inputs()
            # call to compiler_fn below. At runtime, backends use real tensors.
            #
            # This should be a strong invariant we hold across all backends,
            # and generally, it is. However, for accuracy evaluation, we need real tensors at compile time,
            # for now, due to the unfortunate setup described below.
            #
            # We invoke comparison as a backend in two different ways:
            #
            # (1) Less bad, but still worth rewriting: WrapperBackend above, which takes
            # real inputs for its ctor; see the config.verify_correctness check above.
            #
            # (2) More bad, and very worth rewriting: the minifier installs accuracy comparison as
            # a true backend, and therefore needs to be compiled with real inputs. This is made trickier
            # by the fact that the minifier will spawn new processes during minification. As such, we have
            # created a global flag, MINIFIER_SPAWNED, that should be set IF AND ONLY IF this run was spawned
            # as part of accuracy minification. This flag is not a contract, and ideally will not be here long.
            #
            # The longer term PoR is to:
            # (A) Rewrite the minifier accuracy evaluation and verify_correctness code to share the same
            # correctness and accuracy logic, so as not to have two different ways of doing the same thing.
            #
            # (B) Refactor minifier accuracy backend to do its comparison fully at runtime, so as not to need to
            # pass real tensors to it at compile time.
            is_top_level_minifying = (
                config.repro_after is not None and config.repro_level == 4
            )
            if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            else:
                compiled_fn = compiler_fn(gm, self.fake_example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except Exception as e:
            compiled_fn = gm.forward
            raise BackendCompilerFailed(self.compiler_fn, e) from e
        return compiled_fn

    def fake_example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            example = arg.get_fake_examples()
            if example is not None:
                result.extend(example)
            else:
                # Fallback, in case fake_tensor was not set
                # Particularly for graph args that are not tensors
                result.extend(arg.get_examples())
        return result

    def example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            result.extend(arg.get_examples())
        return result

    def remove_unused_graphargs(self) -> None:
        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if node.op == "get_attr":
                    self.remove_node(node)
                elif node.op == "call_function" and node.target is operator.getitem:
                    self.remove_node(node)

        expanded_graphargs = []
        for arg in self.graphargs:
            expanded_graphargs.extend([arg] * len(arg))
            arg.uses = 0

        for node, arg in zip(self.graph.nodes, expanded_graphargs):
            assert node.op == "placeholder"
            arg.uses += len(node.users)

        for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
            if arg.uses == 0:
                log.debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)

        self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
        # NB: self.orig_graphargs is the original graphargs

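    # Example (illustrative): if the traced function never uses one of its
    # tensor arguments, its placeholder is erased and the arg is dropped from
    # self.graphargs (but kept in self.orig_graphargs so shape guards on it
    # are still installed):
    #
    #     def fn(x, unused):
    #         return x * 2  # `unused` has 0 uses -> pruned from the fx graph
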
    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global(self, name, value) -> None:
        self.cleanups.append(CleanupHook.create(self.root_globals, name, value))

    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.

        # Clear cache for conversion of real -> fake tensors
        self.root_tx.fake_mode.fake_tensor_converter = None
        self.root_tx = None

        # Note: the generated fx graph will hold a reference to the nn_modules,
        # so depending on the backend they may not be released
        self.nn_modules = None

        # Cleanup graphargs
        for graph_arg in self.graphargs:
            graph_arg.erase()

        for node in self.graph.nodes:
            if "example_value" in node.meta:
                del node.meta["example_value"]
        self.real_value_cache.clear()
        self.name_to_input.clear()

    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )

        # append stack trace to fx node
        tx = self.current_tx

        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        frame_summaries: List[traceback.FrameSummary] = []
        while tx:
            frame_summaries.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)

        # official from_list stub doesn't have new-style type
        msgs = traceback.StackSummary.from_list(frame_summaries).format()  # type: ignore[arg-type]

        # Carry module_stack along with node.stack_trace for reusing stacktrace propagation infra
        nn_module_stack_str = f"Module stack: {nn_module_stack}\n"
        rv.node.stack_trace = nn_module_stack_str + " | ".join(msgs)

        return rv

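    # Example (illustrative): the metadata attached above makes each fx node
    # self-describing when inspecting a captured graph:
    #
    #     for node in output_graph.graph.nodes:
    #         print(node.name, node.meta.get("nn_module_stack"))
    #         print(node.stack_trace)  # "Module stack: ..." plus user frames
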
    def create_node(self, *args, **kwargs):
        node = super().create_node(*args, **kwargs)
        node.meta["creation_timestamp"] = self.timestamp
        return node

    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        self.graph.erase_node(node)
        self.name_to_input.pop(node.name, None)