Recently I've found it a bit painful to run benchmark scripts in my dev environment. E.g., the command below
```
python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only YituTechConvBert --training
```
took about 2 minutes to run. It may take even longer for some other models.
The command is slow because it needs to
- do the dynamo work
- verify the model on CPU
- run perf tests
- compile all the graphs

However, oftentimes I only need to debug inductor-specific logic like loop ordering and fusion. A lot of the work the script does is useless to me. Also, I only need to test one graph at a time (e.g., check the fwd graph first and, when I'm done, continue to the bwd graph) rather than compile all the graphs.
The graph replayer adds a `@save_args` decorator to the `compile_fx_inner` function. When `config.save_args` is true, it pickles all the arguments to `compile_fx_inner` to the file system. Later on, we can call `load_args_and_run_compile_fx_inner("/tmp/inductor_saved_args/compile_fx_inner_0.pkl")` to replay the graph and compile it with inductor.
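For example, a typical capture-and-replay workflow might look like the sketch below (the exact pickle filename and index depend on how many graphs get captured in your run):
```
# Capture: turn on argument saving, then run the model once under torch.compile
# so dynamo/aot-autograd produce the graphs and compile_fx_inner gets called.
import torch._inductor.config as inductor_config

inductor_config.save_args = True
# ... run the benchmark or model once here ...

# Replay: compile a single saved graph directly with inductor, skipping the
# benchmark harness, dynamo, and aot-autograd entirely.
from torch._inductor.debug import load_args_and_run_compile_fx_inner

load_args_and_run_compile_fx_inner("/tmp/inductor_saved_args/compile_fx_inner_0.pkl")
```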
Replaying the fwd graph took around 60 seconds (maybe this can be reduced further, but it is already a 2x speedup for dev efficiency), and it took only around 20 seconds to reach the `Scheduler.__init__` method.
I also checked the existing `TORCH_COMPILE_DEBUG` flag. The most similar part of `TORCH_COMPILE_DEBUG` is that it can save a graph and its arguments and rerun it later. The difference here is that, rather than running the model, we want to call the inductor API to compile the model (without even going through dynamo or aot-autograd).
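For context, the hook itself is tiny. Below is a minimal sketch of what a `@save_args`-style decorator does, assuming it simply delegates to `save_args_for_compile_fx_inner` from the file below (the actual decorator in the PR may be wired up differently):
```
import functools

from torch._inductor import config
from torch._inductor.debug import save_args_for_compile_fx_inner


def save_args(fn):
    """Record fn's arguments when config.save_args is set, then call fn as usual."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if config.save_args:
            # Pickles (args, kwargs) to /tmp/inductor_saved_args/compile_fx_inner_<n>.pkl,
            # converting tensors to metadata so FakeTensors can be pickled.
            save_args_for_compile_fx_inner(*args, **kwargs)
        return fn(*args, **kwargs)

    return wrapper
```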
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106952
Approved by: https://github.com/jansel
ghstack dependencies: #106990
472 lines · 14 KiB · Python
import collections
import contextlib
import cProfile
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
from typing import Any, List
from unittest.mock import patch

from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled

import torch
from torch import fx as fx

from torch._dynamo.repro.after_aot import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def has_dot():
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False


def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)


def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        def in_output(snode):
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


@contextlib.contextmanager
def enable_aot_logging():
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    if not os.path.exists(path):
        os.makedirs(path)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()


class DebugContext:
    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname

    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            shutil.copytree(self._path, new_path)
            self._path = new_path
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        if config.debug:
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level):
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored


SchedulerNodeList = List[Any]


class DebugFormatter:
    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(*args, **kwargs):
    """
    This function is used to save arguments for a compile_fx_inner function call
    to the file system. Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def handle_tensor(x):
        """
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensor to metadata. This may also makes pickle faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # call print rather than log.debug. log.debug will print message
        # prefix for each line which makes the code snippet harder to be
        # copied.
        # Not a big deal since the code is already been guarded by checking
        # the log level.
        print(message)


def load_args_and_run_compile_fx_inner(path):
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x):
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)