Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
When entering an inductor debug context, we increase the log level of dynamo. This makes sense: if we're debugging inductor, and inductor calls into dynamo, we probably want visibility into what dynamo is doing. But when we exit that context, we should go back to whatever level of dynamo-specific logging was in place before. Dynamo generates a lot of debug info (guards, bytecode), and it's a lot to sift through if you're not specifically interested in it.

Differential Revision: [D41841879](https://our.internmc.facebook.com/intern/diff/D41841879)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90473
Approved by: https://github.com/mlazos, https://github.com/jansel
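The restore-on-exit behavior boils down to registering a callback on the debug context's ExitStack that captures the previous dynamo log level before raising it. Below is a minimal sketch of that pattern; `fake_dynamo_config` is a hypothetical stand-in for dynamo's config module, and the real change lives in `DebugContext.__enter__` in the file below.

# Illustrative sketch only: save a log level, raise it for the duration of a
# block, and restore it on exit via contextlib.ExitStack.
import contextlib
import logging
import types

fake_dynamo_config = types.SimpleNamespace(log_level=logging.WARNING)  # hypothetical stand-in

with contextlib.ExitStack() as stack:

    def reset_log_level(level):
        fake_dynamo_config.log_level = level

    # register the restore callback with the *previous* level, then raise it
    stack.callback(reset_log_level, fake_dynamo_config.log_level)
    fake_dynamo_config.log_level = logging.DEBUG
    # ... run the inductor compilation with extra dynamo visibility ...

# on exit, the callback has restored the original level
assert fake_dynamo_config.log_level == logging.WARNING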
384 lines
12 KiB
Python
import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
from typing import Any, List
from unittest.mock import patch

from functorch.compile import (
    config,
    draw_graph,
    get_aot_graph_name,
    get_graph_being_compiled,
)

import torch
from torch import fx as fx
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    ExternKernelSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
    TemplateSchedulerNode,
)
from .utils import dynamo_config, dynamo_debug_utils, dynamo_utils
from .virtualized import V

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def has_dot():
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False


def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)


def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    # include "template" so the lookup below also covers TemplateSchedulerNode
    func_dict = {
        s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused", "template"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if isinstance(snode, ExternKernelSchedulerNode):
            node_type = "extern"
            group = node_type
        elif isinstance(snode, TemplateSchedulerNode):
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")
        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        def in_output(snode):
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


@contextlib.contextmanager
def enable_aot_logging():
    if not bool(os.environ.get("TORCH_COMPILE_DEBUG", False)):
        yield
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG

    stack = contextlib.ExitStack()
    stack.enter_context(patch("functorch.compile.config.debug_fake_cross_ref", True))
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
    stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
    stack.enter_context(patch("functorch.compile.config.debug_joint", True))
    stack.enter_context(patch("functorch.compile.config.log_level", logging.DEBUG))

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)
    path = os.path.join(dynamo_utils.get_debug_dir(), "aot_torchinductor")
    if not os.path.exists(path):
        os.makedirs(path)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()


class DebugContext:
    """Collects per-compilation debug artifacts (logs, IR dumps, diagrams,
    profiles) into a dedicated debug directory."""

    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return dynamo_debug_utils.wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        for n in DebugContext._counter:
            dirname = os.path.join(
                dynamo_utils.get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname

    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def rename(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # other OSes might have trouble renaming a dir with open files
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        log = logging.getLogger(config.inductor_import)
        if not log.handlers:
            dynamo_utils.init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger(config.inductor_import)
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored


SchedulerNodeList = List[Any]


class DebugFormatter:
    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            dynamo_debug_utils.save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))