import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref
from typing import Optional

import torch
import torch._inductor.test_operators
import torch.distributed
import torch.utils._content_store

from .utils import getfile
from .variables.functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)


"""
|
|
A note on skipfiles:
|
|
|
|
Dynamo consults this file to determine whether function should be inlined or skipped.
|
|
|
|
A skip applies at the frame boundary, meaning dynamo either triggers a graph break
|
|
at the beginning of the frame or attempts to trace/inline the whole frame. When skipping
|
|
a frame, recursively called frames are still traced by dynamo unless also skipped.
|
|
|
|
Skipfiles (skipped at the file level instead of function level) still apply on a
|
|
frame-by-frame boundary as dynamo traces, but apply to all functions in that file.
|
|
|
|
@skip is a helper decorator that can be applied to your function to cause it to be
|
|
included here.
|
|
|
|
Dynamo skip/inline rules & priorities are defined as follows:
|
|
* Inline is the default behavior and will be used unless explicitly skipped.
|
|
* Dynamo has two SKIPLIST: BUILTIN_SKIPLIST and THIRDPARTY_SKIPLIST.
|
|
* BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc.
|
|
* THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc.
|
|
* Functions in these two SKIPLISTs are always skipped, except when they are explicitly
|
|
put into the two INLINELIST: FUNC_INLINELIST and MOD_INLINELIST.
|
|
* PyTorch(torch) is in the BUILTIN_SKIPLIST by default, but there are many cases
|
|
where we want inline the functions under torch namespace. We should add them
|
|
into one of the two *_INLINELIST to make dynamo inline those functions.
|
|
* If you call functions under skipped modules/files, Dynamo will wrap these functions
|
|
as SkipFilesVariable. There are a few functions(e.g, collections.OrderedDict) that
|
|
we have special handling at SkipFilesVariable.call_function.
|
|
|
|
Overall: *_INLINELIST has precedence over *_SKIPLIST has precedence over DEFAULT (inline)
|
|
|
|
To figure out what the behavior is, check the following list in order:
|
|
* FUNC_INLINELIST (Inline if YES)
|
|
* MOD_INLINELIST (Inline if YES)
|
|
* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES)
|
|
* Inline by default
|
|
|
|
In general, if you want to force inline a function or module, please consider adding
|
|
the function's python module to MOD_INLINELIST first.
|
|
Use the FUNC_INLINELIST only when there are other functions under the same module that
|
|
you don't want to inline them.
|
|
"""
|
|
|
|
|
|
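# A doctest-style sketch of the precedence rules above (not executed at import
# time; results assume a standard source checkout of torch):
#
#   >>> import torch._tensor
#   >>> check(torch._tensor._convert)  # FUNC_INLINELIST beats the torch skip
#   False
#   >>> check(torch.nn.functional.relu)  # torch/* is skipped at the top level
#   True
#   >>> check(torch.nn.functional.relu, is_inlined_call=True)  # MOD_INLINELIST ("torch.nn") wins
#   False
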
BUILTIN_SKIPLIST = (
    abc,
    collections,
    contextlib,
    copy,
    copyreg,
    dataclasses,
    enum,
    functools,
    importlib,
    inspect,
    linecache,
    logging,
    multiprocessing,
    operator,
    os,
    posixpath,
    random,
    re,
    selectors,
    signal,
    tempfile,
    threading,
    tokenize,
    torch,  # torch/* is skipped by default unless specified in FUNC_INLINELIST or MOD_INLINELIST
    traceback,
    types,
    typing,
    unittest,
    weakref,
    _collections_abc,
    _weakrefset,
)

# The third party skiplist is defined by module name strings, because users may
# not have these libraries installed. We should use lazy import & skip in the future.
THIRDPARTY_SKIPLIST = (
    "fx2trt_oss",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "xarray",
)


def _strip_init_py(s):
    # Escape the dot so we only strip a literal "__init__.py" suffix.
    return re.sub(r"__init__\.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)


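# For illustration (hypothetical install path), _module_dir(torch) maps
# ".../site-packages/torch/__init__.py" to ".../site-packages/torch/", so the
# string-prefix checks below cover every file under the package directory:
#
#   >>> _strip_init_py("/site-packages/torch/__init__.py")
#   '/site-packages/torch/'
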
# TODO: Add a decorator for easily adding functions to FUNC_INLINELIST
# after resolving all circular import issues.
FUNC_INLINELIST = {
    "torch._constrain_as_size",
    "torch._constrain_as_value",
    "torch._tensor._convert",
}


# These are legacy workarounds; don't add new modules to this list.
# Please use MOD_INLINELIST instead to force inlining of functions under particular modules.
LEGACY_MOD_INLINELIST = {
    "torch._dynamo.external_utils",
    "torch._export.db.examples",
    "torch._export.wrappers",
    "torch._functorch.apis",
    "torch._functorch.deprecated",
    "torch._higher_order_ops.cond",
    "torch.ao.quantization.pt2e.eval_utils",
    "torch.ao.quantization.pt2e.qat_utils",
    "torch.ao.quantization.pt2e.representation.rewrite",
    "torch.ao.quantization.pt2e.utils",
    "torch.ao.quantization.quantizer.xnnpack_quantizer",
    "torch.optim",
}

if torch.distributed.is_available():
    LEGACY_MOD_INLINELIST |= {
        "torch.distributed._tensor.api",
        "torch.distributed._tensor.device_mesh",
        "torch.distributed.device_mesh",
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
        "torch.distributed.tensor.parallel._data_parallel_utils",
        "torch.distributed.tensor.parallel._utils",
        "torch.distributed.tensor.parallel.style",
    }


# Force inline functions under these modules, even if they are in *_SKIPLIST.
# We use python module names instead of file or directory objects to avoid circular dependencies.
# Please keep this sorted alphabetically.
MOD_INLINELIST = {
    "torch._decomp",
    "torch._dynamo._trace_wrapped_higher_order_op",
    "torch._dynamo.comptime",
    "torch._dynamo.polyfill",
    "torch._functorch.vmap",
    "torch._higher_order_ops.strict_mode",
    "torch._inductor.test_operators",
    "torch._prims",
    "torch._refs",
    "torch._tensor",
    "torch.amp.autocast_mode",
    "torch.ao.nn",
    "torch.autograd.function",
    "torch.cuda.amp.autocast_mode",
    "torch.distributions",
    "torch.export.wrapper",
    "torch.fx._pytree",
    "torch.fx.passes.shape_prop",
    "torch.nn",
    "torch.random",
    "torch.sparse",
    "torch.testing",
    "torch.utils._content_store",
    "torch.utils._contextlib",
    "torch.utils._foreach_utils",
    "torch.utils._pytree",
}


if torch.distributed.is_available():
    MOD_INLINELIST.add("torch.distributed")
    MOD_INLINELIST.add("torch.distributed._functional_collectives")


# TODO: support adding bound methods to this list
@functools.lru_cache(None)
def get_func_inlinelist():
    inlinelist = set()
    for f in FUNC_INLINELIST:
        module_name, fn_name = f.rsplit(".", 1)
        m = importlib.import_module(module_name)
        fn = getattr(m, fn_name)
        inlinelist.add(fn.__code__)
    return inlinelist


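# Membership is keyed on code objects rather than function objects, so
# wrappers that re-export the same underlying code still match. Doctest-style
# sketch (not executed at import time):
#
#   >>> import torch._tensor
#   >>> torch._tensor._convert.__code__ in get_func_inlinelist()
#   True
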
@functools.lru_cache(None)
def get_legacy_mod_inlinelist():
    inlinelist = set()
    for m in LEGACY_MOD_INLINELIST:
        # Map the module name to a path prefix, e.g. "torch.optim" -> ".../torch/optim"
        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
    return inlinelist


@functools.lru_cache(None)
def get_mod_inlinelist():
    inlinelist = set()
    for m in MOD_INLINELIST:
        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
    return inlinelist


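# Doctest-style sketch of the name -> path-prefix translation used above
# (hypothetical install prefix):
#
#   >>> (_module_dir(torch) + "fx/_pytree") in get_mod_inlinelist()  # from "torch.fx._pytree"
#   True
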
# Skip some standard python builtin libs.
SKIP_DIRS = [
    "<frozen importlib",
    "<__array_function__ internals>",
] + [_module_dir(m) for m in BUILTIN_SKIPLIST]

# Placeholder pattern that can never match; _recompile_re() rebuilds it from SKIP_DIRS.
SKIP_DIRS_RE = re.compile(r"match nothing^")

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()
# Skip fbcode paths (including torch.package paths) containing
# one of the following strings.
FBCODE_SKIP_DIRS = {
    "torchrec/distributed",
    "torchrec/fb/distributed",
    "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py",
}
FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})")


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")


def add(import_name: str):
    # Accept a module object for convenience, despite the annotation.
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    from importlib.util import find_spec

    module_spec = find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    global SKIP_DIRS_RE
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()


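# Usage sketch for add(): extend the skiplist at runtime with another library
# ("scipy" is only an example here; it is not skipped by default):
#
#   >>> add("scipy")              # appends scipy's directory to SKIP_DIRS
#   >>> add("not_a_real_module")  # no-op: find_spec returns None
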
@dataclasses.dataclass
class SkipResult:
    skipped: bool
    reason: Optional[str]


def check_file(filename, is_inlined_call=False):
    """Should we skip this file?"""
    if filename is None:
        return SkipResult(True, "filename is None")
    if any(filename.startswith(d) for d in get_legacy_mod_inlinelist()):
        return SkipResult(
            False,
            "inlined according to skipfiles.LEGACY_MOD_INLINELIST",
        )
    if is_inlined_call and is_torch_inline_allowed(filename):
        return SkipResult(
            False,
            "inlined according to skipfiles.MOD_INLINELIST",
        )
    if is_fbcode and bool(FBCODE_SKIP_DIRS_RE.match(filename)):
        return SkipResult(
            True,
            "skipped according to skipfiles.FBCODE_SKIP_DIRS",
        )
    if bool(SKIP_DIRS_RE.match(filename)):
        return SkipResult(True, "skipped according to skipfiles.SKIP_DIRS")
    else:
        return SkipResult(False, "inlined by default")


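# Rule ordering matters: a file under torch/nn matches both MOD_INLINELIST and
# SKIP_DIRS, but the inline checks run first. Doctest-style sketch
# (hypothetical path):
#
#   >>> check_file(_module_dir(torch) + "nn/functional.py", is_inlined_call=True).skipped
#   False
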
@dataclasses.dataclass
class FunctionInfo:
    py_obj: Optional[object]
    name: Optional[str]
    filename: str
    code: Optional[types.CodeType]


"""
|
|
This is the main entry point to determine whether an object (function) should be inlined or skipped.
|
|
Let's illustrate the logic with an example:
|
|
@torch.compile
|
|
def f1(x, y):
|
|
......
|
|
f2(x, y)
|
|
......
|
|
|
|
def f2(x, y):
|
|
......
|
|
f3(x, y)
|
|
......
|
|
|
|
def f3(x, y):
|
|
......
|
|
|
|
There are mainly three call sites of check/check_verbose:
|
|
* The compile region entrance (like function f1), the correspoinding code is located at eval_frame.py.
|
|
* When tracing the recursively called functions (like function f2 and f3).
|
|
* Dynamo decides inline/skip everytime it encounters a new recursively function call, and the call site
|
|
is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py.
|
|
* If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again
|
|
and the call site is in catch_errors_wrapper.catch_errors of eval_frame.py.
|
|
* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFilesVariable in builder.py.
|
|
|
|
`is_inlined_call` is used to indicate if the current function call is inlined (f2 is inlined call if it passes check)
|
|
or not (f3 is not inlined call if f2 is skipped). Inside of the `check_verbose` function, there are more rules
|
|
to be checked if this `is_inlined_call`.
|
|
The reason to have this flag is that if the upper level function call (e.g, f2) is skipped,
|
|
we don't want to inline the lower level function call (e.g, f3) by default.
|
|
"""
|
|
|
|
|
|
def check_verbose(obj, is_inlined_call=False):
    if isinstance(
        obj, (UserFunctionVariable, UserMethodVariable, NestedUserFunctionVariable)
    ):
        try:
            py_obj = obj.get_function()
        except NotImplementedError:
            py_obj = None
        fi = FunctionInfo(py_obj, obj.get_name(), obj.get_filename(), obj.get_code())
    elif isinstance(obj, types.CodeType):
        fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj)
    elif isinstance(obj, (types.FunctionType, types.MethodType)):
        fi = FunctionInfo(
            obj, obj.__name__, getfile(obj), obj.__code__  # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed
        )
    else:
        fi = FunctionInfo(obj, None, getfile(obj), None)

    # Go through function based skip/inline rules.
    if fi.code in get_func_inlinelist():
        return SkipResult(
            False,
            "inlined according to skipfiles.FUNC_INLINELIST",
        )
    if is_inlined_call:
        if fi.name == "patched_init":
            return SkipResult(True, "patched init cannot be inlined.")
        elif fi.name == "__torch_function__":
            return SkipResult(False, "allow inlining __torch_function__")

    # Go through file based skip/inline rules.
    return check_file(fi.filename, is_inlined_call)


def check(obj, is_inlined_call=False):
    return check_verbose(obj, is_inlined_call).skipped


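# Doctest-style sketch tying check() back to the reasons produced by
# check_verbose (results assume a standard source checkout of torch):
#
#   >>> import torch._tensor
#   >>> check_verbose(torch._tensor._convert).reason
#   'inlined according to skipfiles.FUNC_INLINELIST'
#   >>> check(abc.abstractmethod)  # abc is in BUILTIN_SKIPLIST
#   True
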
# Skip common third party libs.
for _name in THIRDPARTY_SKIPLIST:
    add(_name)

_recompile_re()


def is_torch_inline_allowed(filename):
    return any(filename.startswith(d) for d in get_mod_inlinelist())


@functools.lru_cache(None)
def dynamo_dir():
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    # Dynamo's own files are not treated as part of torch here.
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))