Add mypy checking for a few files in torch/_dynamo (#89731)

Enabling mypy everywhere is intractable at the moment: there are a
lot of errors, and mypy is also quite slow on this codebase. The goal
here is just enough annotations to describe the public types for user
compiler calls, typing the torch._C._dynamo bindings along the way.
This is a first step in that direction.
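
As a sketch of where this is headed (the backend below is
illustrative, not part of this diff), the payoff is that a
user-written compiler can be checked against the CompilerFn alias
added to output_graph.py:

from typing import List

import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # shape matches CompilerFn: (fx.GraphModule, List[Tensor]) -> callable
    print(gm.graph)
    return gm.forward  # plays the role of the returned CompiledFn

opt_sin = torch._dynamo.optimize(my_compiler)(torch.sin)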

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89731
Approved by: https://github.com/suo
Edward Z. Yang 2022-11-27 20:48:40 -08:00 committed by PyTorch MergeBot
parent 55e8b5c126
commit f45fe7de33
10 changed files with 214 additions and 133 deletions

.lintrunner.toml

@@ -141,6 +141,28 @@ init_command = [
'pyyaml==6.0',
]
[[linter]]
code = 'MYPYNOFOLLOW'
include_patterns = [
'torch/_dynamo/eval_frame.py',
'torch/_dynamo/convert_frame.py',
'torch/_dynamo/types.py',
'torch/_dynamo/output_graph.py',
'torch/_dynamo/optimizations/__init__.py',
'torch/_dynamo/optimizations/backends.py',
'torch/_dynamo/optimizations/training.py',
'torch/_C/_dynamo/**/*.py',
]
exclude_patterns = [
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy-nofollow.ini',
'--',
'@{{PATHSFILE}}'
]
[[linter]]
code = 'MYPYSTRICT'
include_patterns = [


torch/_C/_dynamo/eval_frame.pyi (new file)

@@ -0,0 +1,10 @@
import types
from typing import Union
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def reset_code(code: types.CodeType) -> None: ...
def unsupported(obj1: object, obj2: object) -> object: ...
def skip_code(code: types.CodeType) -> None: ...
def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
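
These stubs exist only for the type checker; at runtime the same
names live on the compiled torch._C._dynamo.eval_frame extension,
which is why torch/_dynamo/eval_frame.py below imports them under
TYPE_CHECKING and otherwise copies the extension's attributes into
globals(). A usage sketch (assuming a dynamo-enabled build; passing
None clears the frame callback):

import torch

# typed as (DynamoCallback) -> DynamoCallback thanks to the stub
prior = torch._C._dynamo.eval_frame.set_eval_frame(None)
torch._C._dynamo.eval_frame.set_eval_frame(prior)  # restore whatever was set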

torch/_dynamo/convert_frame.py

@@ -6,7 +6,8 @@ import traceback
import types
import typing
import weakref
from typing import Callable
from traceback import FrameSummary
from typing import Callable, cast, Dict, List, Optional
import torch
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
@@ -24,6 +25,7 @@ from .exc import (
Unsupported,
)
from .guards import CheckFunctionManager, GuardedCode
from .output_graph import OutputGraph
from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator
from .utils import (
@@ -106,7 +108,7 @@ def wrap_convert_context(fn):
torch.cuda.set_rng_state(cuda_rng_state)
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
_fn._torchdynamo_orig_callable = fn
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
return _fn
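
This ignore and the matching ones below all come from the same mypy
limitation: a function object is typed as a plain callable, so
attaching ad-hoc attributes like _torchdynamo_orig_callable is an
attr-defined error. (The logging.CODE ignores later in this file are
the module-level analogue: dynamo installs that custom log level on
the logging module at runtime, invisibly to the stubs.) A
self-contained reproduction:

from typing import Any, Callable

def mark(fn: Callable[..., Any]) -> Callable[..., Any]:
    def _fn(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)
    # without the ignore: '"Callable[..., Any]" has no attribute "_orig"'
    _fn._orig = fn  # type: ignore[attr-defined]
    return _fn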
@@ -123,7 +125,7 @@ def has_tensor_in_frame(frame):
if is_allowed(frame.f_globals[co_name]):
return True
seen_ids = dict()
seen_ids: Dict[int, bool] = dict()
def has_tensor(obj):
"""Recursively check if the obj has a tensor"""
@@ -194,7 +196,7 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
msg += "".join(
traceback.format_list(
stack_above_dynamo + list(reversed(exc.real_stack))
stack_above_dynamo + list(reversed(get_real_stack(exc)))
)
)
msg += "\n"
@@ -207,13 +209,18 @@
return msg
def get_real_stack(exc) -> List[FrameSummary]:
assert hasattr(exc, "real_stack")
return cast(List[FrameSummary], exc.real_stack)
def augment_exc_message(exc, msg="\n"):
if (
hasattr(exc, "real_stack")
and len(exc.real_stack) > 0
and not (config.verbose and config.suppress_errors)
):
msg += f"\nfrom user code:\n {''.join(traceback.format_list(reversed(exc.real_stack[0:2])))}"
msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"
if config.replay_record_enabled and hasattr(exc, "record_filename"):
msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
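
get_real_stack packages the assert-plus-cast needed for another
dynamically attached attribute: real_stack is stuck onto exceptions
elsewhere in dynamo, so hasattr proves it exists at runtime while
cast states its type to mypy (cast is unchecked, hence the assert).
Note the parameter stays unannotated; typing it as BaseException
would make the attribute access itself an error. The pattern in
isolation:

from traceback import FrameSummary
from typing import cast, List

def stack_of(exc) -> List[FrameSummary]:
    assert hasattr(exc, "real_stack")  # runtime check
    return cast(List[FrameSummary], exc.real_stack)  # static claim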
@@ -344,7 +351,7 @@ def convert_frame_assert(
def _compile(
code,
code: types.CodeType,
globals,
locals,
builtins,
@@ -353,8 +360,8 @@ def _compile(
export,
guard_export_fn=None,
frame=None,
):
output = None
) -> Optional[GuardedCode]:
output: Optional[OutputGraph] = None
# from .utils import print_once; print_once(code.co_filename)
def transform(instructions, code_options):
@@ -372,6 +379,7 @@ def _compile(
)
tracer.run()
output = tracer.output
assert output is not None
assert output.output_instructions
instructions[:] = output.output_instructions
code_options.update(output.code_options)
@@ -400,7 +408,7 @@ def _compile(
output_codes.add(out_code)
log.log(
logging.CODE,
logging.CODE, # type: ignore[attr-defined]
format_bytecode(
"ORIGINAL BYTECODE",
code.co_name,
@@ -410,7 +418,7 @@ def _compile(
),
)
log.log(
logging.CODE,
logging.CODE, # type: ignore[attr-defined]
format_bytecode(
"MODIFIED BYTECODE",
code.co_name,
@@ -420,6 +428,7 @@ def _compile(
),
)
assert output is not None
assert output.guards is not None
CleanupManager.instance[out_code] = output.cleanups
check_fn = CheckFunctionManager(output, output.guards, locals, globals)
@@ -428,7 +437,7 @@ def _compile(
guard_str = "GUARDS:\n"
guard_str += "\n".join([f" - {str(guard)}" for guard in sorted(output.guards)])
log.log(logging.CODE, guard_str)
log.log(logging.CODE, guard_str) # type: ignore[attr-defined]
if guard_export_fn is not None:
guard_export_fn(output.guards)
@@ -464,7 +473,7 @@ def convert_frame(compiler_fn: typing.Callable, guard_export_fn=None):
raise
return None
_convert_frame._torchdynamo_orig_callable = compiler_fn
_convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
return _convert_frame

torch/_dynamo/eval_frame.py

@@ -9,7 +9,9 @@ import threading
import traceback
import types
import warnings
from enum import Enum
from importlib import import_module
from typing import Optional, Tuple, TYPE_CHECKING, Union
from unittest.mock import patch
import torch
@@ -17,31 +19,45 @@ import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.parallel.distributed import DistributedDataParallel
if TYPE_CHECKING:
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_eval_frame,
set_guard_error_hook,
set_guard_fail_hook,
skip_code,
unsupported,
)
else:
for name in dir(torch._C._dynamo.eval_frame):
if name.startswith("__"):
continue
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
from . import config, convert_frame, skipfiles, utils
from .exc import ResetRequired
from .mutation_guard import install_generation_tagging_init
from .optimizations.distributed import DDPOptimizer
from .output_graph import CompilerFn
from .types import DynamoCallback
from .utils import compile_times
log = logging.getLogger(__name__)
try:
from torch.fx.experimental import proxy_tensor
except ImportError:
proxy_tensor = None
from torch.fx.experimental import proxy_tensor
_eval_frame = torch._C._dynamo.eval_frame
set_eval_frame = _eval_frame.set_eval_frame
reset_code = _eval_frame.reset_code
unsupported = _eval_frame.unsupported
skip_code = _eval_frame.skip_code
set_guard_fail_hook = _eval_frame.set_guard_fail_hook
set_guard_error_hook = _eval_frame.set_guard_error_hook
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
unset = object()
# See https://github.com/python/typing/pull/240
class Unset(Enum):
token = 0
unset = Unset.token
compile_lock = threading.RLock()
most_recent_backend = None
most_recent_backend: Optional[CompilerFn] = None
class OptimizedModule(torch.nn.Module):
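
Two idioms above deserve a note. The TYPE_CHECKING branch imports
from the new stubs so mypy sees real signatures, while the runtime
branch copies the C extension's attributes into globals(), leaving
behavior unchanged. Separately, the Unset enum replaces
unset = object() because a single-member enum gives the sentinel a
nameable type: self.prior can be annotated Union[Unset,
DynamoCallback], and identity checks narrow it. A minimal sketch of
the sentinel trick (names hypothetical):

from enum import Enum
from typing import Union

class _Unset(Enum):
    token = 0

_UNSET = _Unset.token

def restore(prior: Union[_Unset, int]) -> int:
    if prior is _UNSET:
        return 0
    return prior  # mypy has narrowed Union[_Unset, int] to int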
@@ -113,7 +129,7 @@ def enable_dynamic(enable: bool = True):
class _TorchDynamoContext:
def __init__(
self,
callback,
callback: DynamoCallback,
on_enter=nothing,
backend_ctx_ctor=null_context,
patch_fn=nothing,
@@ -123,8 +139,8 @@ class _TorchDynamoContext:
):
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback = callback
self.prior = unset
self.callback: DynamoCallback = callback
self.prior: Union[Unset, DynamoCallback] = unset
self.on_enter = on_enter
self.extra_ctx_ctor = backend_ctx_ctor
self.first_ctx = first_ctx
@@ -146,6 +162,7 @@ class _TorchDynamoContext:
self.dynamic_ctx.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
assert self.prior is not unset
set_eval_frame(self.prior)
self.prior = unset
# TODO: This is totally not the right way to chain contexts manually
@@ -198,13 +215,13 @@ class _TorchDynamoContext:
# hooks to properly handle inlining
if isinstance(self, DisableContext):
_fn._torchdynamo_disable = True
_fn._torchdynamo_disable = True # type: ignore[attr-defined]
else:
_fn._torchdynamo_inline = fn
_fn._torchdynamo_inline = fn # type: ignore[attr-defined]
# Save the function pointer to find the original callable while nesting
# of decorators.
_fn._torchdynamo_orig_callable = fn
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
# If the function is called using torch._dynamo.optimize decorator, we
# should prevent any type of skipping.
@@ -306,7 +323,7 @@ def catch_errors_wrapper(callback):
with compile_lock:
return callback(frame, cache_size)
catch_errors._torchdynamo_orig_callable = callback
catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return catch_errors
@@ -510,7 +527,7 @@ def export(
graph = None
out_guards = None
graph_captured_input = None
graph_captured_result = None
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
def produce_matching(source_args, candidate_args):
matched_elements_positions = []
@@ -559,6 +576,7 @@ def export(
nonlocal graph_captured_input
graph_captured_input = graph_inputs
assert graph is not None
graph_captured_result = graph(*graph_inputs)
return graph_captured_result
@@ -585,6 +603,7 @@ def export(
flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
assert graph_captured_result is not None
flat_both = list(graph_captured_result) + flat_args
matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
@@ -710,8 +729,7 @@ class TorchPatcher:
torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
torch.distributions.Distribution.set_default_validate_args(False)
if proxy_tensor is not None:
proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
optimizers = [
opt

torch/_dynamo/guards.py

@@ -805,7 +805,7 @@ def ___make_guard_fn({','.join(closure_vars.keys())}):
def guard_fail_hook(
guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool
):
) -> None:
"""
called whenever a guard fails.
"""

torch/_dynamo/optimizations/backends.py

@@ -6,15 +6,18 @@ import os
import subprocess
import tempfile
from typing import Dict
import numpy as np
import torch
from ..output_graph import CompilerFn
from ..utils import identity
from .subgraph import SubGraph
log = logging.getLogger(__name__)
BACKENDS = dict()
BACKENDS: Dict[str, CompilerFn] = dict()
_NP_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
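
Two recurring edits in this file: backends that ship no type stubs
(onnxruntime, tensorflow, tvm, and friends) get
# type: ignore[import] on their imports, and the *args parameters are
renamed before being rebound, because mypy types *args as a tuple and
rejects assigning a list back to it. A minimal reproduction of the
latter (hypothetical function):

def _call(*initial_args: float) -> float:
    # rebinding the parameter itself (args = [...] when args is *args)
    # fails with: Incompatible types in assignment (List vs Tuple)
    args = [abs(a) for a in initial_args]
    return sum(args)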
@@ -130,7 +133,7 @@ def static_runtime(subgraph):
def onnxrt_common(subgraph, provider, onnx_filename=None):
import onnxruntime
import onnxruntime # type: ignore[import]
assert provider in onnxruntime.get_available_providers()
session = onnxruntime.InferenceSession(
@@ -141,9 +144,9 @@ def onnxrt_common(subgraph, provider, onnx_filename=None):
create_outputs = subgraph.empty_outputs_factory()
is_cpu = subgraph.is_cpu
def _call(*args):
def _call(*initial_args):
binding = session.io_binding()
args = [a.contiguous() for a in args]
args = [a.contiguous() for a in initial_args]
for name, value in zip(input_names, args):
dev = value.device
binding.bind_input(
@@ -228,7 +231,7 @@ def onnxrt(subgraph):
@functools.lru_cache(None)
def _init_tensorflow():
import tensorflow as tf
import tensorflow as tf # type: ignore[import]
# prevent tensorflow from eating all the GPU memory
gpus = tf.config.list_physical_devices("GPU")
@@ -239,8 +242,8 @@ def _init_tensorflow():
@create_backend
def onnx2tf(subgraph):
import onnx
from onnx_tf.backend import prepare
import onnx # type: ignore[import]
from onnx_tf.backend import prepare # type: ignore[import]
tf = _init_tensorflow()
filename = subgraph.filename("tensorflow")
@@ -253,8 +256,8 @@ def onnx2tf(subgraph):
tf_module = tf.saved_model.load(filename)
tf_module = tf.function(tf_module, jit_compile=True)
def run(*args):
args = [a.contiguous() for a in args]
def run(*i_args):
args = [a.contiguous() for a in i_args]
with tf.device(device):
outs = tf_module(
**{
@@ -292,7 +295,7 @@ def taso(subgraph):
@create_backend
def ipex(subgraph, **kwargs):
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # type: ignore[import]
inputs = subgraph.example_inputs
model = subgraph.model
@@ -321,12 +324,20 @@ def fx2trt(subgraph, **kwargs):
# TensorRT fails violently with an abort() on this
return None
from torch_tensorrt.fx.fx2trt import InputTensorSpec, TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.fx.fx2trt import ( # type: ignore[import]
InputTensorSpec,
TRTInterpreter,
)
from torch_tensorrt.fx.passes.lower_basic_pass import ( # type: ignore[import]
transform_setitem,
)
from torch_tensorrt.fx.tools.trt_splitter import ( # type: ignore[import]
TRTSplitter,
TRTSplitterSetting,
)
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer # type: ignore[import]
from torch_tensorrt.fx.trt_module import TRTModule # type: ignore[import]
from torch_tensorrt.fx.utils import LowerPrecision # type: ignore[import]
from .normalize import normalize_ir
@@ -414,7 +425,7 @@ def torch2trt(subgraph):
# TensorRT fails violently with an abort() on this
return None
from torch2trt import torch2trt
from torch2trt import torch2trt # type: ignore[import]
inputs = subgraph.example_inputs
trt_mod = torch2trt(
@@ -438,45 +449,6 @@ def tensorrt(subgraph):
return model
@create_backend
def onnx2tensorrt_alt(subgraph):
if subgraph.will_tensorrt_barf():
# TensorRT fails violently with an abort() on this
return None
import tensorrt as trt
from torch.fx.experimental.fx2trt.trt_module import TRTModule
inputs = subgraph.example_inputs
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
config = builder.create_builder_config()
assert isinstance(inputs, (list, tuple))
inputs = tuple(inputs)
input_names = subgraph.input_names
output_names = subgraph.output_names
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse(open(subgraph.onnx_filename, "rb").read())
for idx in range(parser.num_errors):
print(parser.get_error(idx))
assert success
config.max_workspace_size = 1 << 25
config.set_flag(trt.BuilderFlag.STRICT_TYPES)
builder.max_batch_size = len(inputs[0])
engine = builder.build_engine(network, config)
assert engine
trt_mod = TRTModule(engine, input_names, output_names)
return subgraph.wrap_returns(trt_mod)
@create_backend
def cudagraphs(subgraph):
model = subgraph.model
@@ -628,9 +600,9 @@ def tvm_compile_inner(
jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False
):
try:
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import tvm # type: ignore[import]
from tvm import relay # type: ignore[import]
from tvm.contrib import graph_executor # type: ignore[import]
shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
@@ -724,8 +696,8 @@ def tvm_compile_inner(
return torch.from_numpy(nd_tensor.numpy())
return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
def exec_tvm(*args):
args = [a.contiguous() for a in args]
def exec_tvm(*i_args):
args = [a.contiguous() for a in i_args]
for idx, arg in enumerate(args, 0):
if arg.dim() != 0:
if arg.requires_grad:

torch/_dynamo/optimizations/training.py

@@ -148,7 +148,7 @@ class AotNop(AotAutogradStrategy):
DEBUG = False
return BACKENDS["aot_autograd"](
self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop
)
) # type: ignore[call-arg]
aot_eager = AotNop.compile_fn
@@ -164,7 +164,7 @@ class AotTorchscript(AotAutogradStrategy):
return BACKENDS["aot_autograd"](
self.gm, self.example_inputs, fw_compiler=ts_compile
)
) # type: ignore[call-arg]
aot_ts = AotTorchscript.compile_fn
@@ -214,7 +214,7 @@ class AotMemEfficientFusion(AotAutogradStrategy):
def candidate(self):
kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
@@ -222,7 +222,7 @@ class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
def candidate(self):
kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
class AotInductorDebug(AotAutogradStrategy):
@@ -247,7 +247,7 @@ class AotInductorDebug(AotAutogradStrategy):
min_cut_rematerialization_partition, compiler="inductor"
),
}
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) # type: ignore[call-arg]
aot_inductor_debug = AotInductorDebug.compile_fn
@@ -346,7 +346,7 @@ def create_nvprims_backend(*, executor):
fw_compiler=partial(prims_executor, executor=self.executor),
bw_compiler=partial(prims_executor, executor=self.executor),
partition_fn=disable(nvprims_fw_bw_partition_fn),
)
) # type: ignore[call-arg]
return NvPrims
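
The ignore[call-arg] comments are the price of typing BACKENDS as
Dict[str, CompilerFn]: the aot_autograd entry really does accept
keyword arguments beyond the (gm, example_inputs) signature that
CompilerFn declares, so each kwarg call site opts out. The shape of
the mismatch, reduced to a toy (names hypothetical):

from typing import Callable, Dict

Fn = Callable[[int], int]

def impl(x: int, scale: int = 1) -> int:
    return x * scale

TABLE: Dict[str, Fn] = {"impl": impl}  # fine: impl is compatible with Fn
TABLE["impl"](2, scale=3)              # error: call-arg; kwargs not in Fn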

torch/_dynamo/output_graph.py

@@ -7,7 +7,10 @@ import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Set, Tuple, Union
import sympy
from typing_extensions import Protocol
import torch.nn
from torch import fx
@@ -16,8 +19,8 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config, logging as torchdynamo_logging, variables
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed
from .guards import GuardBuilder
from .exc import BackendCompilerFailed, unimplemented
from .guards import Guard, GuardBuilder, TensorReference
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import ConstantSource, LocalSource, Source
@@ -31,7 +34,8 @@ from .utils import (
format_graph_tabular,
same,
)
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.base import VariableTracker
from .variables.builder import GraphArg, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
DynamicShapeVariable,
@@ -43,6 +47,15 @@ from .variables.tensor import (
log = logging.getLogger(__name__)
# TODO: I think this accepts int arguments too
class CompiledFn(Protocol):
def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
...
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
@functools.lru_cache(None)
def _step_logger():
return torchdynamo_logging.get_step_logger(log)
@@ -75,27 +88,27 @@ class FakeRootModule(torch.nn.Module):
return "FakeRootModule(...)"
def wrap_compiler_fn(compiler_fn):
def wrap_compiler_fn(compiler_fn: CompilerFn) -> CompilerFn:
"""WrapperBackend if config.verify_correctness is True"""
if config.verify_correctness:
# wrap backend if verify_correctness is True
wrapper_backend_compiler_fn = WrapperBackend(compiler_fn)
wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn
wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
return wrapper_backend_compiler_fn
return compiler_fn
class WrapperBackend:
def __init__(self, backend=None):
self.backend = backend
def __init__(self, backend: CompilerFn):
self.backend: CompilerFn = backend
@property
def example_inputs(self):
return clone_inputs(self.original_example_inputs)
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
self.restore = checkpoint_params(gm)
self.original_example_inputs = clone_inputs(example_inputs)
@@ -138,38 +151,42 @@ class OutputGraph(fx.Tracer):
self,
f_globals: Dict[str, Any],
code_options: Dict[str, Any],
compiler_fn: Callable,
compiler_fn: CompilerFn,
root_tx,
):
super(OutputGraph, self).__init__()
# Mutable state checkpointed by copy_graphstate()
self.graph = torch.fx.Graph()
self.graphargs = []
self.guards = set()
self.nn_modules = dict()
self.graphargs: List[GraphArg] = []
self.guards: Set[Guard] = set()
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
self.side_effects = SideEffects()
self.code_options = dict(code_options)
self.output_instructions = []
self.output_instructions: List[Instruction] = []
# Node => computed real value (see utils.get_real_value)
self.real_value_cache = {}
self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}
# Not checkpointed
self.compiler_fn = compiler_fn
self.compiler_fn: CompilerFn = compiler_fn
self.root_globals = f_globals
self.root_tx = root_tx
self.cleanups = []
self.cleanups: List[CleanupHook] = []
self.should_exit = False
self.random_values_var = None
self.initial_random_state = ()
self.unspec_variable_map = {}
self.unspec_variable_map: Dict[
str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
] = {}
self.shape_env = ShapeEnv() if config.dynamic_shapes else None
self.tensor_id_to_sym_shape_ref = {}
self.intermediary_symbols = {}
self.tensor_id_to_sym_shape_ref: Dict[int, Set[TensorReference]] = {}
self.intermediary_symbols: Dict[sympy.Expr, None] = {}
# Enables creating unique node names by tracking
# all current placeholder node names
self.name_to_input = collections.OrderedDict()
self.name_to_input: OrderedDict[
str, Optional[fx.Proxy]
] = collections.OrderedDict()
@property
def output(self):
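
A small subtlety in the annotations above: OrderedDict comes from
typing, not collections, because on the oldest Python supported here
(pre-3.9), collections.OrderedDict is not subscriptable at runtime,
while the typing alias is. Sketch:

import collections
from typing import Optional, OrderedDict

# annotate with typing.OrderedDict; construct the collections one
name_to_input: OrderedDict[str, Optional[int]] = collections.OrderedDict()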
@@ -181,6 +198,7 @@ class OutputGraph(fx.Tracer):
def copy_graphstate(self):
"""Create a checkpoint of the current state by copying everything"""
assert self.nn_modules is not None
graph_nodes = set(self.graph.nodes)
return (
graph_nodes,
@@ -313,6 +331,7 @@ class OutputGraph(fx.Tracer):
target
)
assert self.nn_modules is not None
for k, v in self.nn_modules.items():
if v is target:
# it already exists
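
nn_modules becomes Optional here, presumably because cleanup() below
releases it to break the tracer/OutputGraph reference cycle, and the
asserts added before each use are how mypy narrows the Optional back
off. The pattern in isolation (toy class):

from typing import Dict, Optional

class Graph:
    def __init__(self) -> None:
        self.nn_modules: Optional[Dict[str, object]] = {}

    def count(self) -> int:
        assert self.nn_modules is not None  # narrows Optional[Dict] -> Dict
        return len(self.nn_modules)

    def cleanup(self) -> None:
        self.nn_modules = None  # why the attribute must be Optional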
@@ -356,11 +375,14 @@ class OutputGraph(fx.Tracer):
tx.prune_dead_locals()
stack_values = list(tx.stack)
assert self.nn_modules is not None
root = FakeRootModule(self.nn_modules)
# Add all the local vars to the "stack" so restore at the end
restore_vars = []
val_to_names = collections.OrderedDict()
val_to_names: OrderedDict[
VariableTracker, List[str]
] = collections.OrderedDict()
if stack_values:
val_to_names[stack_values[-1]] = list()
for k, v in tx.symbolic_locals.items():
@@ -494,7 +516,7 @@ class OutputGraph(fx.Tracer):
# the call to tabulate can cause a lot of memory to be allocated
if config.log_level <= logging.INFO:
log.log(
logging.CODE,
logging.CODE, # type: ignore[attr-defined]
f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n",
)
except ImportError:
@@ -508,7 +530,7 @@ class OutputGraph(fx.Tracer):
cg.make_call_generated_code(name)
return cg.get_instructions()
def call_user_compiler(self, gm):
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
try:
name = (
self.compiler_fn.__name__
@@ -527,13 +549,13 @@ class OutputGraph(fx.Tracer):
raise BackendCompilerFailed(self.compiler_fn, e) from e
return compiled_fn
def example_inputs(self):
def example_inputs(self) -> List[torch.Tensor]:
result = []
for arg in self.graphargs:
result.extend(arg.get_examples())
return result
def remove_unused_graphargs(self):
def remove_unused_graphargs(self) -> None:
for node in reversed(list(self.graph.nodes)):
if len(list(node.users)) == 0:
if node.op == "get_attr":
@@ -560,7 +582,7 @@
self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
def add_output_instructions(self, prefix: List[Instruction]):
def add_output_instructions(self, prefix: List[Instruction]) -> None:
"""
We call this on the creation of a new compiled subgraph that is inserted
before user code.
@@ -568,10 +590,10 @@
self.output_instructions.extend(prefix)
self.should_exit = True
def install_global(self, name, value):
def install_global(self, name, value) -> None:
self.cleanups.append(CleanupHook.create(self.root_globals, name, value))
def cleanup(self):
def cleanup(self) -> None:
# There is a reference cycle between tracer and OutputGraph, causing
# some of the tensor objects to be held alive for longer than necessary.
@@ -620,7 +642,8 @@
frame_summaries.append(tx.frame_summary())
tx = getattr(tx, "parent", None)
msgs = traceback.StackSummary.from_list(frame_summaries).format()
# official from_list stub doesn't have new-style type
msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type]
# Carry module_stack along with node.stack_trace for reusing stacktrace propagation infra
nn_module_stack_str = f"Module stack: {nn_module_stack}\n"

torch/_dynamo/types.py (new file)

@@ -0,0 +1,27 @@
import types
from typing import Any, Callable, Dict, Optional, Union
from typing_extensions import Protocol
from torch._dynamo.guards import GuardedCode
class DynamoCallbackFn(Protocol):
def __call__(
self, frame: types.FrameType, cache_size: int
) -> Optional[GuardedCode]:
...
DynamoCallback = Union[DynamoCallbackFn, None, bool]
class DynamoGuardHook(Protocol):
def __call__(
self,
guard_fn: Callable,
code: types.CodeType,
f_locals: Dict[str, Any],
last: bool,
) -> None:
...
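
Since these are Protocols, the typing is structural: any callable
whose signature matches conforms, with no subclassing or
registration. A hedged sketch (never_compile is hypothetical;
returning None tells dynamo to leave the frame to regular CPython):

import types
from typing import Optional

from torch._dynamo.guards import GuardedCode
from torch._dynamo.types import DynamoCallback

def never_compile(frame: types.FrameType, cache_size: int) -> Optional[GuardedCode]:
    return None  # decline to compile this frame

cb: DynamoCallback = never_compile  # accepted via DynamoCallbackFn's __call__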