# Owner(s): ["module: dynamo"]
import dataclasses
import importlib
import inspect
import math
import types
import unittest
import warnings
from typing import Any, Dict, Set

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable

try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function


ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    # Ignored because they go through the rules defined at `trace_rules.check`.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch.cuda._get_device_properties",
}
if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in gen_allowed_objs_and_ids.
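# Each printed line pairs a fully qualified torch name with its Dynamo wrapping
# rule; an illustrative (not captured from a real run) line looks like:
#   "torch.abs": TorchInGraphFunctionVariable,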
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')


@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs
    from the heuristic defined in `gen_allowed_objs_and_ids`.
    """

    object_ids: Dict[int, str]
    c_binding_in_graph_functions: Set[Any]
    non_c_binding_in_graph_functions: Set[Any]
    name_rule_map: Dict[str, Any]


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and record the ids of all the objects in it.
    """

    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = dict()

    # On some platforms, these functions are loaded as classes instead of functions.
    # To handle these weird cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }

    # Add obj to the c_binding_in_graph_functions or non_c_binding_in_graph_functions
    # set if it's a torch function or method.
    # This is used to generate the in-graph function lists based on the heuristic.
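    # The classification below relies on the `__code__` attribute: Python-level
    # functions carry one, while C bindings do not, so its absence is treated as
    # the signal that the object is a C binding.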
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchInGraphFunctionVariable
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
        ]
        if config.trace_distributed:
            disallowed_modules.append("torch.distributed.")

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

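    # Recursively walk `module`, recording an id -> qualified-name mapping for
    # every allowed submodule and object, and (when `record` is set) classifying
    # callables via `heuristic_record_if_in_graph_function`.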
    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`.
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )


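# The tests below exercise gen_allowed_objs_and_ids: the generated C-binding set
# must stay in sync with `trace_rules.torch_c_binding_in_graph_functions`, and
# manual trace rules must be able to force inlining of otherwise skipped functions.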
class TraceRuleTests(torch._dynamo.test_case.TestCase):
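    # `generated` holds objects discovered by the heuristic walk and `used` holds
    # objects named in the trace_rules map plus the ignored set; the assertions
    # fail if either side has entries missing from the other.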
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
        x = generated - used
        y = used - generated
        msg1 = (
            f"New torch objects: {x} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {y} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertTrue(len(x) == 0, msg1)
        self.assertTrue(len(y) == 0, msg2)

    # We use string names of Python functions and modules for these inlinelists;
    # this unit test makes sure the functions/modules can be correctly imported
    # or loaded, catching any typos in the strings.
    def test_skipfiles_inlinelist(self):
        for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
            self.assertTrue(
                isinstance(importlib.import_module(m), types.ModuleType),
                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a Python module; please check and correct it.",
            )

    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in `allowed_functions.py`.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test that C-binding in-graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        used = set()
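        # `load_object` resolves a dotted name string back to the Python object
        # it names; entries that fail to resolve (None) are skipped below.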
        for x in (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        ):
            obj = load_object(x)
            if obj is not None:
                used.add(obj)
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
        # For non-C-binding in-graph functions, we only test that they can be loaded successfully.
        for f in torch_non_c_binding_in_graph_functions:
            self.assertTrue(
                isinstance(
                    load_object(f),
                    (
                        types.FunctionType,
                        types.BuiltinFunctionType,
                        types.MethodDescriptorType,
                        types.WrapperDescriptorType,
                    ),
                )
            )

    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default.
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inlining of `torch._dynamo.utils.istype` by setting a trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST
        )
        self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)

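        # Patching `torch_name_rule_map` alone is not enough: `get_torch_obj_rule_map`
        # caches its result with functools.lru_cache, so the unwrapped (uncached)
        # function is patched in as well.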
        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
        ):
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inlining of `mod.func` by setting a trace rule.
        _manual_torch_name_rule_map[
            f"{mod.__name__}.{func.__name__}"
        ] = UserFunctionVariable

        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]

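        # The manual rule above is expected to take precedence over the SKIP_DIRS
        # entry added below, so `func` still gets inlined (fullgraph=True would
        # otherwise fail) even though its module is skipped by default.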
        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        ):
            # First add the module to SKIP_DIRS so that it is skipped by default.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|