Add mypy checking for a few files in torch/_dynamo (#89731)

It's not really tractable to enable mypy everywhere at the moment: there are a lot of errors, and mypy is also quite slow for some reason. I just want enough types to describe the public types for user compiler calls, typing the _C._dynamo bindings along the way. This is a first step in that direction.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89731
Approved by: https://github.com/suo
parent 55e8b5c126 · commit f45fe7de33
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -141,6 +141,28 @@ init_command = [
     'pyyaml==6.0',
 ]
 
+[[linter]]
+code = 'MYPYNOFOLLOW'
+include_patterns = [
+    'torch/_dynamo/eval_frame.py',
+    'torch/_dynamo/convert_frame.py',
+    'torch/_dynamo/types.py',
+    'torch/_dynamo/output_graph.py',
+    'torch/_dynamo/optimizations/__init__.py',
+    'torch/_dynamo/optimizations/backends.py',
+    'torch/_dynamo/optimizations/training.py',
+    'torch/_C/_dynamo/**/*.py',
+]
+exclude_patterns = [
+]
+command = [
+    'python3',
+    'tools/linter/adapters/mypy_linter.py',
+    '--config=mypy-nofollow.ini',
+    '--',
+    '@{{PATHSFILE}}'
+]
+
 [[linter]]
 code = 'MYPYSTRICT'
 include_patterns = [
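For context: lintrunner is the driver for these adapters; it writes the selected file paths into a temporary file and substitutes its path for {{PATHSFILE}}. A rough Python equivalent of the resulting call, assuming the @-prefixed response-file convention the placeholder syntax suggests (the file list here is hypothetical):

import subprocess
import tempfile

# Illustrative only: lintrunner normally performs this substitution itself.
files = ["torch/_dynamo/types.py", "torch/_dynamo/eval_frame.py"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as pathsfile:
    pathsfile.write("\n".join(files))

subprocess.run(
    [
        "python3",
        "tools/linter/adapters/mypy_linter.py",
        "--config=mypy-nofollow.ini",
        "--",
        f"@{pathsfile.name}",  # assumed: adapter reads paths from this response file
    ],
    check=True,
)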
new file torch/_C/_dynamo/__init__.pyi (empty)
new file torch/_C/_dynamo/eval_frame.pyi (10 lines):
--- /dev/null
+++ b/torch/_C/_dynamo/eval_frame.pyi
@@ -0,0 +1,10 @@
+import types
+from typing import Union
+from torch._dynamo.types import DynamoCallback, DynamoGuardHook
+
+def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
+def reset_code(code: types.CodeType) -> None: ...
+def unsupported(obj1: object, obj2: object) -> object: ...
+def skip_code(code: types.CodeType) -> None: ...
+def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
+def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
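With this stub in place, calls into the extension are type-checked. A minimal sketch (assuming the extension submodule is importable, as the TYPE_CHECKING import added to eval_frame.py below relies on):

import types
from typing import Optional

from torch._C._dynamo import eval_frame as ef
from torch._dynamo.guards import GuardedCode

def noop_callback(frame: types.FrameType, cache_size: int) -> Optional[GuardedCode]:
    return None  # None means "don't compile this frame"

prior = ef.set_eval_frame(noop_callback)  # checked against the stub above
ef.set_eval_frame(prior)  # restore the previous callback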
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -6,7 +6,8 @@ import traceback
 import types
 import typing
 import weakref
-from typing import Callable
+from traceback import FrameSummary
+from typing import Callable, cast, Dict, List, Optional
 
 import torch
 from torch.fx.graph_module import _forward_from_src as original_forward_from_src
@@ -24,6 +25,7 @@ from .exc import (
     Unsupported,
 )
 from .guards import CheckFunctionManager, GuardedCode
+from .output_graph import OutputGraph
 from .replay_record import ExecutionRecord
 from .symbolic_convert import InstructionTranslator
 from .utils import (
@@ -106,7 +108,7 @@ def wrap_convert_context(fn):
             torch.cuda.set_rng_state(cuda_rng_state)
             torch.fx.graph_module._forward_from_src = prior_fwd_from_src
 
-    _fn._torchdynamo_orig_callable = fn
+    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
     return _fn
 
 
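The `# type: ignore[attr-defined]` here (and on the similar assignments later in this diff) is needed because mypy does not model ad-hoc attributes on function objects. The pattern in miniature:

from typing import Any, Callable

def wrap(fn: Callable[..., Any]) -> Callable[..., Any]:
    def _fn(*args: Any, **kwargs: Any) -> Any:
        return fn(*args, **kwargs)

    # mypy rejects arbitrary attributes on functions, hence the ignore
    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn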
@@ -123,7 +125,7 @@ def has_tensor_in_frame(frame):
             if is_allowed(frame.f_globals[co_name]):
                 return True
 
-    seen_ids = dict()
+    seen_ids: Dict[int, bool] = dict()
 
     def has_tensor(obj):
         """Recursively check if the obj has a tensor"""
@@ -194,7 +196,7 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
 
         msg += "".join(
             traceback.format_list(
-                stack_above_dynamo + list(reversed(exc.real_stack))
+                stack_above_dynamo + list(reversed(get_real_stack(exc)))
             )
         )
         msg += "\n"
@@ -207,13 +209,18 @@ def format_error_msg(exc, code, record_filename=None, frame=None):
     return msg
 
 
+def get_real_stack(exc) -> List[FrameSummary]:
+    assert hasattr(exc, "real_stack")
+    return cast(List[FrameSummary], exc.real_stack)
+
+
 def augment_exc_message(exc, msg="\n"):
     if (
         hasattr(exc, "real_stack")
         and len(exc.real_stack) > 0
         and not (config.verbose and config.suppress_errors)
     ):
-        msg += f"\nfrom user code:\n {''.join(traceback.format_list(reversed(exc.real_stack[0:2])))}"
+        msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}"
 
     if config.replay_record_enabled and hasattr(exc, "record_filename"):
         msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
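get_real_stack hides an unchecked attribute access behind a typed helper: typing.cast is a no-op at runtime and only tells mypy what type to assume. The same idiom reduced to a toy (names here are hypothetical):

from typing import cast, List

def get_lines(record) -> List[str]:  # `record` is deliberately untyped (Any)
    assert hasattr(record, "lines")  # runtime expectation
    return cast(List[str], record.lines)  # no-op at runtime; informs mypy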
@@ -344,7 +351,7 @@ def convert_frame_assert(
 
 
 def _compile(
-    code,
+    code: types.CodeType,
     globals,
     locals,
     builtins,
@@ -353,8 +360,8 @@ def _compile(
     export,
     guard_export_fn=None,
     frame=None,
-):
-    output = None
+) -> Optional[GuardedCode]:
+    output: Optional[OutputGraph] = None
 
     # from .utils import print_once; print_once(code.co_filename)
     def transform(instructions, code_options):
@@ -372,6 +379,7 @@ def _compile(
         )
         tracer.run()
         output = tracer.output
+        assert output is not None
         assert output.output_instructions
         instructions[:] = output.output_instructions
         code_options.update(output.code_options)
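This assert (like the one added after the bytecode logging below) exists mainly for the type checker: `output` is now declared Optional[OutputGraph], and `assert output is not None` narrows it for the attribute accesses that follow. The narrowing in miniature:

from typing import Optional

def describe(maybe_name: Optional[str]) -> int:
    assert maybe_name is not None  # narrows Optional[str] to str
    return len(maybe_name)  # OK: mypy now knows this is str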
@@ -400,7 +408,7 @@ def _compile(
         output_codes.add(out_code)
 
     log.log(
-        logging.CODE,
+        logging.CODE,  # type: ignore[attr-defined]
         format_bytecode(
             "ORIGINAL BYTECODE",
             code.co_name,
@@ -410,7 +418,7 @@ def _compile(
         ),
     )
     log.log(
-        logging.CODE,
+        logging.CODE,  # type: ignore[attr-defined]
         format_bytecode(
             "MODIFIED BYTECODE",
             code.co_name,
@@ -420,6 +428,7 @@ def _compile(
         ),
     )
 
+    assert output is not None
     assert output.guards is not None
     CleanupManager.instance[out_code] = output.cleanups
     check_fn = CheckFunctionManager(output, output.guards, locals, globals)
@@ -428,7 +437,7 @@ def _compile(
         guard_str = "GUARDS:\n"
         guard_str += "\n".join([f" - {str(guard)}" for guard in sorted(output.guards)])
 
-        log.log(logging.CODE, guard_str)
+        log.log(logging.CODE, guard_str)  # type: ignore[attr-defined]
 
     if guard_export_fn is not None:
         guard_export_fn(output.guards)
@@ -464,7 +473,7 @@ def convert_frame(compiler_fn: typing.Callable, guard_export_fn=None):
             raise
         return None
 
-    _convert_frame._torchdynamo_orig_callable = compiler_fn
+    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
     return _convert_frame
 
 
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -9,7 +9,9 @@ import threading
 import traceback
 import types
 import warnings
+from enum import Enum
 from importlib import import_module
+from typing import Optional, Tuple, TYPE_CHECKING, Union
 from unittest.mock import patch
 
 import torch
@@ -17,31 +19,45 @@ import torch.utils._pytree as pytree
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn.parallel.distributed import DistributedDataParallel
 
+if TYPE_CHECKING:
+    from torch._C._dynamo.eval_frame import (  # noqa: F401
+        reset_code,
+        set_eval_frame,
+        set_guard_error_hook,
+        set_guard_fail_hook,
+        skip_code,
+        unsupported,
+    )
+else:
+    for name in dir(torch._C._dynamo.eval_frame):
+        if name.startswith("__"):
+            continue
+        globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
+
 from . import config, convert_frame, skipfiles, utils
 from .exc import ResetRequired
 from .mutation_guard import install_generation_tagging_init
 from .optimizations.distributed import DDPOptimizer
+from .output_graph import CompilerFn
+from .types import DynamoCallback
 from .utils import compile_times
 
 log = logging.getLogger(__name__)
 
-try:
-    from torch.fx.experimental import proxy_tensor
-except ImportError:
-    proxy_tensor = None
+from torch.fx.experimental import proxy_tensor
 
-_eval_frame = torch._C._dynamo.eval_frame
-set_eval_frame = _eval_frame.set_eval_frame
-reset_code = _eval_frame.reset_code
-unsupported = _eval_frame.unsupported
-skip_code = _eval_frame.skip_code
-set_guard_fail_hook = _eval_frame.set_guard_fail_hook
-set_guard_error_hook = _eval_frame.set_guard_error_hook
 always_optimize_code_objects = utils.ExactWeakKeyDictionary()
 null_context = contextlib.nullcontext
-unset = object()
+
+# See https://github.com/python/typing/pull/240
+class Unset(Enum):
+    token = 0
+
+
+unset = Unset.token
 
 compile_lock = threading.RLock()
-most_recent_backend = None
+most_recent_backend: Optional[CompilerFn] = None
 
 
 class OptimizedModule(torch.nn.Module):
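This is the standard trick for typed access to a C extension: under TYPE_CHECKING, mypy resolves the names against the new .pyi stub, while at runtime the attributes are copied off the extension module. A condensed sketch of the same pattern:

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # resolved against torch/_C/_dynamo/eval_frame.pyi by mypy only
    from torch._C._dynamo.eval_frame import set_eval_frame  # noqa: F401
else:
    # at runtime, bind the real symbol from the extension module
    set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame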
@@ -113,7 +129,7 @@ def enable_dynamic(enable: bool = True):
 class _TorchDynamoContext:
     def __init__(
         self,
-        callback,
+        callback: DynamoCallback,
         on_enter=nothing,
         backend_ctx_ctor=null_context,
         patch_fn=nothing,
@@ -123,8 +139,8 @@ class _TorchDynamoContext:
     ):
         super().__init__()
         assert callable(callback) or callback is False or callback is None
-        self.callback = callback
-        self.prior = unset
+        self.callback: DynamoCallback = callback
+        self.prior: Union[Unset, DynamoCallback] = unset
         self.on_enter = on_enter
         self.extra_ctx_ctor = backend_ctx_ctor
         self.first_ctx = first_ctx
@@ -146,6 +162,7 @@ class _TorchDynamoContext:
         self.dynamic_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        assert self.prior is not unset
         set_eval_frame(self.prior)
         self.prior = unset
         # TODO: This is totally not the right way to chain contexts manually
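The Unset enum replaces the old `unset = object()` sentinel because a single-member enum can appear in a type (`Union[Unset, DynamoCallback]`) and be narrowed away by identity checks, which an anonymous object() cannot (see the typing discussion linked in the diff). Roughly:

from enum import Enum
from typing import Union

class Unset(Enum):
    token = 0

def restore(prior: Union[Unset, int]) -> int:
    if prior is Unset.token:
        raise RuntimeError("nothing to restore")
    return prior  # narrowed to int: Unset has only one member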
@@ -198,13 +215,13 @@ class _TorchDynamoContext:
 
         # hooks to properly handle inlining
         if isinstance(self, DisableContext):
-            _fn._torchdynamo_disable = True
+            _fn._torchdynamo_disable = True  # type: ignore[attr-defined]
         else:
-            _fn._torchdynamo_inline = fn
+            _fn._torchdynamo_inline = fn  # type: ignore[attr-defined]
 
         # Save the function pointer to find the original callable while nesting
         # of decorators.
-        _fn._torchdynamo_orig_callable = fn
+        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
 
         # If the function is called using torch._dynamo.optimize decorator, we
         # should prevent any type of skipping.
@@ -306,7 +323,7 @@ def catch_errors_wrapper(callback):
             with compile_lock:
                 return callback(frame, cache_size)
 
-    catch_errors._torchdynamo_orig_callable = callback
+    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
     return catch_errors
 
 
@@ -510,7 +527,7 @@ def export(
     graph = None
     out_guards = None
     graph_captured_input = None
-    graph_captured_result = None
+    graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None
 
     def produce_matching(source_args, candidate_args):
         matched_elements_positions = []
@@ -559,6 +576,7 @@ def export(
         nonlocal graph_captured_input
 
         graph_captured_input = graph_inputs
+        assert graph is not None
        graph_captured_result = graph(*graph_inputs)
        return graph_captured_result
 
@@ -585,6 +603,7 @@ def export(
 
     flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)
 
+    assert graph_captured_result is not None
     flat_both = list(graph_captured_result) + flat_args
     matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
 
@@ -710,8 +729,7 @@ class TorchPatcher:
         torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
         torch.distributions.Distribution.set_default_validate_args(False)
 
-        if proxy_tensor is not None:
-            proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
+        proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
 
         optimizers = [
             opt
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -805,7 +805,7 @@ def ___make_guard_fn({','.join(closure_vars.keys())}):
 
 def guard_fail_hook(
     guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool
-):
+) -> None:
     """
     called whenever a guard fails.
     """
--- a/torch/_dynamo/optimizations/backends.py
+++ b/torch/_dynamo/optimizations/backends.py
@@ -6,15 +6,18 @@ import os
 import subprocess
 import tempfile
 
+from typing import Dict
+
 import numpy as np
 
 import torch
+from ..output_graph import CompilerFn
 
 from ..utils import identity
 from .subgraph import SubGraph
 
 log = logging.getLogger(__name__)
-BACKENDS = dict()
+BACKENDS: Dict[str, CompilerFn] = dict()
 _NP_DTYPE = {
     torch.float16: np.float16,
     torch.float32: np.float32,
@@ -130,7 +133,7 @@ def static_runtime(subgraph):
 
 
 def onnxrt_common(subgraph, provider, onnx_filename=None):
-    import onnxruntime
+    import onnxruntime  # type: ignore[import]
 
     assert provider in onnxruntime.get_available_providers()
     session = onnxruntime.InferenceSession(
@@ -141,9 +144,9 @@ def onnxrt_common(subgraph, provider, onnx_filename=None):
     create_outputs = subgraph.empty_outputs_factory()
     is_cpu = subgraph.is_cpu
 
-    def _call(*args):
+    def _call(*initial_args):
         binding = session.io_binding()
-        args = [a.contiguous() for a in args]
+        args = [a.contiguous() for a in initial_args]
         for name, value in zip(input_names, args):
             dev = value.device
             binding.bind_input(
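The `*args` → `*initial_args` renames in this file are also mypy-driven: inside the function, `args` is bound as a tuple, and rebinding the same name to a list changes its type, which mypy rejects. Binding a fresh name avoids the error:

import torch

def call_contiguous(*initial_args: torch.Tensor) -> torch.Tensor:
    # new name: `initial_args` stays Tuple[Tensor, ...]; `args` is List[Tensor]
    args = [a.contiguous() for a in initial_args]
    return args[0]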
@@ -228,7 +231,7 @@ def onnxrt(subgraph):
 
 @functools.lru_cache(None)
 def _init_tensorflow():
-    import tensorflow as tf
+    import tensorflow as tf  # type: ignore[import]
 
     # prevent tensorflow from eating all the GPU memory
     gpus = tf.config.list_physical_devices("GPU")
@@ -239,8 +242,8 @@ def _init_tensorflow():
 
 @create_backend
 def onnx2tf(subgraph):
-    import onnx
-    from onnx_tf.backend import prepare
+    import onnx  # type: ignore[import]
+    from onnx_tf.backend import prepare  # type: ignore[import]
 
     tf = _init_tensorflow()
     filename = subgraph.filename("tensorflow")
@@ -253,8 +256,8 @@ def onnx2tf(subgraph):
     tf_module = tf.saved_model.load(filename)
     tf_module = tf.function(tf_module, jit_compile=True)
 
-    def run(*args):
-        args = [a.contiguous() for a in args]
+    def run(*i_args):
+        args = [a.contiguous() for a in i_args]
         with tf.device(device):
             outs = tf_module(
                 **{
@@ -292,7 +295,7 @@ def taso(subgraph):
 
 @create_backend
 def ipex(subgraph, **kwargs):
-    import intel_extension_for_pytorch as ipex
+    import intel_extension_for_pytorch as ipex  # type: ignore[import]
 
     inputs = subgraph.example_inputs
     model = subgraph.model
@@ -321,12 +324,20 @@ def fx2trt(subgraph, **kwargs):
         # TensorRT fails violently with an abort() on this
         return None
 
-    from torch_tensorrt.fx.fx2trt import InputTensorSpec, TRTInterpreter
-    from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
-    from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
-    from torch_tensorrt.fx.trt_module import TRTModule
-    from torch_tensorrt.fx.utils import LowerPrecision
+    from torch_tensorrt.fx.fx2trt import (  # type: ignore[import]
+        InputTensorSpec,
+        TRTInterpreter,
+    )
+    from torch_tensorrt.fx.passes.lower_basic_pass import (  # type: ignore[import]
+        transform_setitem,
+    )
+    from torch_tensorrt.fx.tools.trt_splitter import (  # type: ignore[import]
+        TRTSplitter,
+        TRTSplitterSetting,
+    )
+    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer  # type: ignore[import]
+    from torch_tensorrt.fx.trt_module import TRTModule  # type: ignore[import]
+    from torch_tensorrt.fx.utils import LowerPrecision  # type: ignore[import]
 
     from .normalize import normalize_ir
 
@@ -414,7 +425,7 @@ def torch2trt(subgraph):
         # TensorRT fails violently with an abort() on this
         return None
 
-    from torch2trt import torch2trt
+    from torch2trt import torch2trt  # type: ignore[import]
 
     inputs = subgraph.example_inputs
     trt_mod = torch2trt(
@@ -438,45 +449,6 @@ def tensorrt(subgraph):
     return model
 
 
-@create_backend
-def onnx2tensorrt_alt(subgraph):
-    if subgraph.will_tensorrt_barf():
-        # TensorRT fails violently with an abort() on this
-        return None
-
-    import tensorrt as trt
-
-    from torch.fx.experimental.fx2trt.trt_module import TRTModule
-
-    inputs = subgraph.example_inputs
-
-    logger = trt.Logger(trt.Logger.ERROR)
-    builder = trt.Builder(logger)
-    config = builder.create_builder_config()
-    assert isinstance(inputs, (list, tuple))
-    inputs = tuple(inputs)
-    input_names = subgraph.input_names
-    output_names = subgraph.output_names
-    network = builder.create_network(
-        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    )
-    parser = trt.OnnxParser(network, logger)
-    success = parser.parse(open(subgraph.onnx_filename, "rb").read())
-    for idx in range(parser.num_errors):
-        print(parser.get_error(idx))
-    assert success
-
-    config.max_workspace_size = 1 << 25
-    config.set_flag(trt.BuilderFlag.STRICT_TYPES)
-    builder.max_batch_size = len(inputs[0])
-
-    engine = builder.build_engine(network, config)
-    assert engine
-
-    trt_mod = TRTModule(engine, input_names, output_names)
-    return subgraph.wrap_returns(trt_mod)
-
-
 @create_backend
 def cudagraphs(subgraph):
     model = subgraph.model
@@ -628,9 +600,9 @@ def tvm_compile_inner(
     jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False
 ):
     try:
-        import tvm
-        from tvm import relay
-        from tvm.contrib import graph_executor
+        import tvm  # type: ignore[import]
+        from tvm import relay  # type: ignore[import]
+        from tvm.contrib import graph_executor  # type: ignore[import]
 
         shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
         mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
@@ -724,8 +696,8 @@ def tvm_compile_inner(
             return torch.from_numpy(nd_tensor.numpy())
         return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
 
-    def exec_tvm(*args):
-        args = [a.contiguous() for a in args]
+    def exec_tvm(*i_args):
+        args = [a.contiguous() for a in i_args]
         for idx, arg in enumerate(args, 0):
             if arg.dim() != 0:
                 if arg.requires_grad:
--- a/torch/_dynamo/optimizations/training.py
+++ b/torch/_dynamo/optimizations/training.py
@@ -148,7 +148,7 @@ class AotNop(AotAutogradStrategy):
         DEBUG = False
         return BACKENDS["aot_autograd"](
             self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop
-        )
+        )  # type: ignore[call-arg]
 
 
 aot_eager = AotNop.compile_fn
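The `# type: ignore[call-arg]` comments in this file follow from the new BACKENDS annotation: once BACKENDS is a Dict[str, CompilerFn], every entry is assumed to take exactly (gm, example_inputs), so backend-specific kwargs like fw_compiler become call-arg errors even though they work at runtime. A reduced illustration (all names hypothetical):

from typing import Callable, Dict, List

CompilerFn = Callable[[object, List[object]], object]

def aot_autograd(gm: object, example_inputs: List[object], **kwargs: object) -> object:
    return gm

BACKENDS: Dict[str, CompilerFn] = {"aot_autograd": aot_autograd}

def compile_with(gm: object, inputs: List[object]) -> object:
    # fine at runtime (the entry accepts **kwargs) but not per the declared
    # CompilerFn signature, hence the ignore
    return BACKENDS["aot_autograd"](gm, inputs, fw_compiler=None)  # type: ignore[call-arg]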
@@ -164,7 +164,7 @@ class AotTorchscript(AotAutogradStrategy):
 
         return BACKENDS["aot_autograd"](
             self.gm, self.example_inputs, fw_compiler=ts_compile
-        )
+        )  # type: ignore[call-arg]
 
 
 aot_ts = AotTorchscript.compile_fn
@@ -214,7 +214,7 @@ class AotMemEfficientFusion(AotAutogradStrategy):
 
     def candidate(self):
         kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
 
 
 class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
@@ -222,7 +222,7 @@ class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
 
     def candidate(self):
         kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
 
 
 class AotInductorDebug(AotAutogradStrategy):
@@ -247,7 +247,7 @@ class AotInductorDebug(AotAutogradStrategy):
                 min_cut_rematerialization_partition, compiler="inductor"
             ),
         }
-        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)
+        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)  # type: ignore[call-arg]
 
 
 aot_inductor_debug = AotInductorDebug.compile_fn
@@ -346,7 +346,7 @@ def create_nvprims_backend(*, executor):
                 fw_compiler=partial(prims_executor, executor=self.executor),
                 bw_compiler=partial(prims_executor, executor=self.executor),
                 partition_fn=disable(nvprims_fw_bw_partition_fn),
-            )
+            )  # type: ignore[call-arg]
 
     return NvPrims
 
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -7,7 +7,10 @@ import operator
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, OrderedDict, Set, Tuple, Union
 
 import sympy
+from typing_extensions import Protocol
 
 import torch.nn
 from torch import fx
@@ -16,8 +19,8 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from . import config, logging as torchdynamo_logging, variables
 from .bytecode_transformation import create_instruction, Instruction, unique_id
 from .codegen import PyCodegen
-from .exc import BackendCompilerFailed
-from .guards import GuardBuilder
+from .exc import BackendCompilerFailed, unimplemented
+from .guards import Guard, GuardBuilder, TensorReference
 from .mutation_guard import is_dynamic_nn_module
 from .side_effects import SideEffects
 from .source import ConstantSource, LocalSource, Source
@@ -31,7 +34,8 @@ from .utils import (
     format_graph_tabular,
     same,
 )
-from .variables.builder import VariableBuilder, wrap_fx_proxy
+from .variables.base import VariableTracker
+from .variables.builder import GraphArg, VariableBuilder, wrap_fx_proxy
 from .variables.nn_module import NNModuleVariable
 from .variables.tensor import (
     DynamicShapeVariable,
|
@ -43,6 +47,15 @@ from .variables.tensor import (
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: I think this accepts int arguments too
|
||||
class CompiledFn(Protocol):
|
||||
def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||
...
|
||||
|
||||
|
||||
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _step_logger():
|
||||
return torchdynamo_logging.get_step_logger(log)
|
||||
|
|
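CompilerFn and CompiledFn are structural (Protocol-based) types, so any callable with a compatible shape qualifies without inheriting from anything. A minimal sketch of a conforming backend (as the TODO above notes, the real argument types are looser, and gm.forward only approximately matches CompiledFn):

from typing import List

import torch
from torch import fx

def logging_backend(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
    gm.graph.print_tabular()  # inspect what dynamo captured
    return gm.forward  # run the captured graph unmodified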
@@ -75,27 +88,27 @@ class FakeRootModule(torch.nn.Module):
         return "FakeRootModule(...)"
 
 
-def wrap_compiler_fn(compiler_fn):
+def wrap_compiler_fn(compiler_fn: CompilerFn) -> CompilerFn:
     """WrapperBackend if config.verify_correctness is True"""
     if config.verify_correctness:
         # wrap backend if verify_correctness is True
         wrapper_backend_compiler_fn = WrapperBackend(compiler_fn)
 
-        wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn
+        wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
         return wrapper_backend_compiler_fn
 
     return compiler_fn
 
 
 class WrapperBackend:
-    def __init__(self, backend=None):
-        self.backend = backend
+    def __init__(self, backend: CompilerFn):
+        self.backend: CompilerFn = backend
 
     @property
     def example_inputs(self):
         return clone_inputs(self.original_example_inputs)
 
-    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
 
         self.restore = checkpoint_params(gm)
         self.original_example_inputs = clone_inputs(example_inputs)
@@ -138,38 +151,42 @@ class OutputGraph(fx.Tracer):
         self,
         f_globals: Dict[str, Any],
         code_options: Dict[str, Any],
-        compiler_fn: Callable,
+        compiler_fn: CompilerFn,
         root_tx,
     ):
         super(OutputGraph, self).__init__()
 
         # Mutable state checkpointed by copy_graphstate()
         self.graph = torch.fx.Graph()
-        self.graphargs = []
-        self.guards = set()
-        self.nn_modules = dict()
+        self.graphargs: List[GraphArg] = []
+        self.guards: Set[Guard] = set()
+        self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
         self.side_effects = SideEffects()
         self.code_options = dict(code_options)
-        self.output_instructions = []
+        self.output_instructions: List[Instruction] = []
         # Node => computed real value (see utils.get_real_value)
-        self.real_value_cache = {}
+        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}
 
         # Not checkpointed
-        self.compiler_fn = compiler_fn
+        self.compiler_fn: CompilerFn = compiler_fn
         self.root_globals = f_globals
         self.root_tx = root_tx
-        self.cleanups = []
+        self.cleanups: List[CleanupHook] = []
         self.should_exit = False
         self.random_values_var = None
         self.initial_random_state = ()
-        self.unspec_variable_map = {}
+        self.unspec_variable_map: Dict[
+            str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable]
+        ] = {}
         self.shape_env = ShapeEnv() if config.dynamic_shapes else None
-        self.tensor_id_to_sym_shape_ref = {}
-        self.intermediary_symbols = {}
+        self.tensor_id_to_sym_shape_ref: Dict[int, Set[TensorReference]] = {}
+        self.intermediary_symbols: Dict[sympy.Expr, None] = {}
 
         # Enables creating unique node names by tracking
         # all current placeholder node names
-        self.name_to_input = collections.OrderedDict()
+        self.name_to_input: OrderedDict[
+            str, Optional[fx.Proxy]
+        ] = collections.OrderedDict()
 
     @property
     def output(self):
@@ -181,6 +198,7 @@ class OutputGraph(fx.Tracer):
 
     def copy_graphstate(self):
         """Create a checkpoint of the current state by copying everything"""
+        assert self.nn_modules is not None
         graph_nodes = set(self.graph.nodes)
         return (
             graph_nodes,
@@ -313,6 +331,7 @@ class OutputGraph(fx.Tracer):
                 target
             )
 
+        assert self.nn_modules is not None
         for k, v in self.nn_modules.items():
             if v is target:
                 # it already exists
@@ -356,11 +375,14 @@ class OutputGraph(fx.Tracer):
 
         tx.prune_dead_locals()
         stack_values = list(tx.stack)
+        assert self.nn_modules is not None
         root = FakeRootModule(self.nn_modules)
 
         # Add all the local vars to the "stack" so restore at the end
         restore_vars = []
-        val_to_names = collections.OrderedDict()
+        val_to_names: OrderedDict[
+            VariableTracker, List[str]
+        ] = collections.OrderedDict()
         if stack_values:
             val_to_names[stack_values[-1]] = list()
         for k, v in tx.symbolic_locals.items():
@@ -494,7 +516,7 @@ class OutputGraph(fx.Tracer):
                 # the call to tabulate can cause a lot of memory to be allocated
                 if config.log_level <= logging.INFO:
                     log.log(
-                        logging.CODE,
+                        logging.CODE,  # type: ignore[attr-defined]
                         f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n",
                     )
             except ImportError:
@@ -508,7 +530,7 @@ class OutputGraph(fx.Tracer):
         cg.make_call_generated_code(name)
         return cg.get_instructions()
 
-    def call_user_compiler(self, gm):
+    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
         try:
             name = (
                 self.compiler_fn.__name__
@@ -527,13 +549,13 @@ class OutputGraph(fx.Tracer):
             raise BackendCompilerFailed(self.compiler_fn, e) from e
         return compiled_fn
 
-    def example_inputs(self):
+    def example_inputs(self) -> List[torch.Tensor]:
         result = []
         for arg in self.graphargs:
             result.extend(arg.get_examples())
         return result
 
-    def remove_unused_graphargs(self):
+    def remove_unused_graphargs(self) -> None:
         for node in reversed(list(self.graph.nodes)):
             if len(list(node.users)) == 0:
                 if node.op == "get_attr":
@@ -560,7 +582,7 @@ class OutputGraph(fx.Tracer):
 
         self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
 
-    def add_output_instructions(self, prefix: List[Instruction]):
+    def add_output_instructions(self, prefix: List[Instruction]) -> None:
         """
         We call this on the creation of a new compiled subgraph that is inserted
         before user code.
@@ -568,10 +590,10 @@ class OutputGraph(fx.Tracer):
         self.output_instructions.extend(prefix)
         self.should_exit = True
 
-    def install_global(self, name, value):
+    def install_global(self, name, value) -> None:
         self.cleanups.append(CleanupHook.create(self.root_globals, name, value))
 
-    def cleanup(self):
+    def cleanup(self) -> None:
         # There is a reference cycle between tracer and OutputGraph, causing
         # some of the tensor objects to be held alive for longer than necessary.
 
|
@ -620,7 +642,8 @@ class OutputGraph(fx.Tracer):
|
|||
frame_summaries.append(tx.frame_summary())
|
||||
tx = getattr(tx, "parent", None)
|
||||
|
||||
msgs = traceback.StackSummary.from_list(frame_summaries).format()
|
||||
# official from_list stub doesn't have new-style type
|
||||
msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type]
|
||||
|
||||
# Carry module_stack along with node.stack_trace for reusing stacktrace propagation infra
|
||||
nn_module_stack_str = f"Module stack: {nn_module_stack}\n"
|
||||
|
|
|
|||
new file torch/_dynamo/types.py (27 lines):
--- /dev/null
+++ b/torch/_dynamo/types.py
@@ -0,0 +1,27 @@
+import types
+from typing import Any, Callable, Dict, Optional, Union
+
+from typing_extensions import Protocol
+
+from torch._dynamo.guards import GuardedCode
+
+
+class DynamoCallbackFn(Protocol):
+    def __call__(
+        self, frame: types.FrameType, cache_size: int
+    ) -> Optional[GuardedCode]:
+        ...
+
+
+DynamoCallback = Union[DynamoCallbackFn, None, bool]
+
+
+class DynamoGuardHook(Protocol):
+    def __call__(
+        self,
+        guard_fn: Callable,
+        code: types.CodeType,
+        f_locals: Dict[str, Any],
+        last: bool,
+    ) -> None:
+        ...
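Because these are Protocols, conformance is structural: a plain function with the right signature is accepted wherever DynamoCallback or DynamoGuardHook is expected, with no registration or inheritance. A small sketch:

import types
from typing import Any, Callable, Dict, Optional

from torch._dynamo.guards import GuardedCode
from torch._dynamo.types import DynamoCallback, DynamoGuardHook

def skip_everything(frame: types.FrameType, cache_size: int) -> Optional[GuardedCode]:
    return None  # a valid DynamoCallbackFn: never compile

def log_guard_failure(
    guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool
) -> None:
    print(f"guard failed in {code.co_name} (last={last})")

callback: DynamoCallback = skip_everything  # None, True, and False are also valid
hook: DynamoGuardHook = log_guard_failure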