import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
import sys
from typing import Any, List
from unittest.mock import patch

from functorch.compile import (
    config as functorch_config,
    draw_graph,
    get_aot_graph_name,
    get_graph_being_compiled,
)

import torch
from torch import fx as fx

from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def has_dot():
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False

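# Note: this shells out to the POSIX `which` binary, so it reports False on
# systems without `which` even if graphviz's `dot` is installed on PATH.

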
def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather metadata
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)

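# Usage sketch (hypothetical: `snodes` stands for the node list produced by
# the inductor Scheduler for the graph being compiled):
#
#     if has_dot():
#         draw_buffers(snodes, print_graph=True, fname="fusion_graph")
#
# This renders the fused scheduling graph to fusion_graph.svg via graphviz.

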
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    func_dict = {
        s: get_fake_func(s) for s in ["extern", "template", "nop", "compute", "fused"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")
        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        def in_output(snode):
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph

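# Example (hypothetical `snodes`, as above): the returned fx.Graph is purely
# structural -- every call_function target is a fake op returning 0 -- so it
# is meant for visualization and inspection rather than execution:
#
#     g = create_fx_from_snodes(snodes)
#     for n in g.nodes:
#         if "fusion_meta" in n.meta:
#             print(n.name, n.meta["fusion_meta"].type, n.meta["fusion_meta"].group)

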
@contextlib.contextmanager
def enable_aot_logging():
    compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    stack.enter_context(patch("functorch.compile.config.log_level", logging.DEBUG))
    # if the user has asked to see graphs via either env var,
    # stream them to stdout
    if debug_graphs or debug_joint_graphs:
        stdout_handler = logging.StreamHandler(sys.stdout)
        log.addHandler(stdout_handler)
        stack.callback(lambda: log.removeHandler(stdout_handler))

    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
    stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
    stack.enter_context(patch("functorch.compile.config.debug_joint", True))

    path = os.path.join(get_debug_dir(), "aot_torchinductor")
    if not os.path.exists(path):
        os.makedirs(path)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()

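# Usage sketch: wrap the AOT compilation step so its graphs get logged; with
# TORCH_COMPILE_DEBUG set to a non-empty value they also land in a per-graph
# file under <debug_dir>/aot_torchinductor/:
#
#     with enable_aot_logging():
#         compiled_fn = run_aot_compilation(gm, example_inputs)  # hypothetical

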
class DebugContext:
    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

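    # Usage sketch (hypothetical compiler pass): wrap() runs the pass inside a
    # fresh DebugContext, then layers dynamo's debug wrapper on top:
    #
    #     @DebugContext.wrap
    #     def my_inductor_pass(gm, example_inputs):
    #         ...  # V.debug now points at a live DebugContext
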
    @staticmethod
    def create_debug_dir(folder_name):
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname

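    # Note: _counter is shared across instances, so successive compilations in
    # one process claim <name>.0, <name>.1, ... without rescanning from zero.
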
    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def rename(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # other OSes might have trouble renaming a dir with open files
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

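    # The dumped profile can also be inspected offline with the stdlib, e.g.:
    #
    #     import pstats
    #     pstats.Stats("compile.prof").sort_stats("cumtime").print_stats(20)
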
    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored

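# Every DebugFormatter hook reached through __getattr__ above is therefore
# gated twice: tracing must be on globally (config.trace.enabled) and the
# specific artifact flag (e.g. config.trace.fx_graph) must be set; otherwise
# the call resolves to a silent no-op.
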

SchedulerNodeList = List[Any]


class DebugFormatter:
    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


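# End-to-end sketch (hypothetical script): the artifacts above are written
# once tracing is enabled, either by setting TORCH_COMPILE_DEBUG=1 in the
# environment or by flipping the config flags directly:
#
#     import torch._inductor.config as inductor_config
#     inductor_config.trace.enabled = True
#     inductor_config.trace.fx_graph = True
#     inductor_config.trace.graph_diagram = True  # requires graphviz `dot`
#     # ...then compile a model with torch.compile(backend="inductor")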