pytorch/scripts/jit/log_extract.py
David Berard b27ec57331 [JIT] script & logging for extracting IR from logs (#72889)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72889

The script, together with the GRAPH_EXPORT macro, provides an easy way to extract IR from logs. One use case in this diff is extracting the fusion groups from nvfuser so that the fusions can be tested individually.

Usage (e.g. for nvfuser test)

1. Write some test.py file that uses nvfuser
2. `PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 test.py 2>&1 | tee output.txt`
3. `python3 pytorch/scripts/jit/log_extract.py output.txt --nvfuser`

This will run with and without nvfuser to compare the output.

Alternatively, use `--output` to dump the IR so that it can be used in other applications.

Currently, only `--output` works (since generating input tensors is not supported)

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D34440189

Pulled By: davidberard98

fbshipit-source-id: fca0f619200ee37aba34bb39b69e6c640c263e26
(cherry picked from commit eb319166075db160f1628f0de545641fbecde8be)
2022-03-02 18:34:35 +00:00

166 lines
5.6 KiB
Python

from contextlib import contextmanager
from torch.testing import make_tensor
from typing import Any, List, Tuple
import argparse
import torch
'''
Usage:
1. Run your script and pipe into a log file
PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract:
log_extract.py log.txt --nvfuser
You can also print the list of extracted IR:
log_extract.py log.txt --output
'''
def extract_ir(filename: str) -> List[str]:
    """Collect every IR dump enclosed in GRAPH_EXPORT markers in a log file.

    The text that precedes ``<GRAPH_EXPORT>`` on the marker's own line is
    treated as the per-line log prefix; that many leading characters are cut
    from every line of the enclosed region so only the IR text remains.

    Args:
        filename: Path of the log file to scan.

    Returns:
        One string per complete ``<GRAPH_EXPORT>...</GRAPH_EXPORT>`` region,
        in file order. Regions with no closing marker are skipped.
    """
    BEGIN = "<GRAPH_EXPORT>"
    END = "</GRAPH_EXPORT>"
    graphs = []
    with open(filename, "r") as f:
        split_strs = f.read().split(BEGIN)

    # Every segment after the first starts right after a BEGIN marker; the
    # segment before it ends with whatever shared the marker's line.
    for preceding, split_str in zip(split_strs, split_strs[1:]):
        end_loc = split_str.find(END)
        if end_loc == -1:
            continue
        s = split_str[:end_loc]

        # The prefix is the unterminated last line of the preceding text.
        # If that text is empty, or ends with a line break, the marker began
        # its own line and there is no prefix to strip.
        tail = preceding.splitlines(keepends=True)[-1:]
        last_line = tail[0] if tail else ""
        pfx = last_line if last_line.splitlines() == [last_line] else ""

        lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
        graphs.append(''.join(lines))
    return graphs
def make_tensor_from_type(inp_type: torch._C.TensorType):
    """Create a tensor with the sizes, dtype and device reported by ``inp_type``.

    The tensor is produced by ``torch.testing.make_tensor``.

    Raises:
        NotImplementedError: unless ``inp_type.requires_grad()`` is exactly
            ``False`` (so ``True`` and ``None`` are both rejected).
    """
    grad_flag = inp_type.requires_grad()
    if grad_flag is False:
        shape = inp_type.sizes()
        return make_tensor(shape, dtype=inp_type.dtype(), device=inp_type.device())
    raise NotImplementedError("Tensors with requires_grad are not implemented")
def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
    """Parse IR text into a callable function plus one default value per input.

    Default inputs: ``.5`` for float inputs, ``2`` for int inputs, and a
    tensor from ``make_tensor_from_type`` for tensor inputs.

    Args:
        ir: Text handed to ``torch._C.parse_ir``.

    Returns:
        ``(func, inputs)`` where ``func`` comes from
        ``torch._C._create_function_from_graph`` and ``inputs`` lists the
        generated values in graph-input order.

    Raises:
        NotImplementedError: if an input has a type other than float, int or
            tensor.
    """
    parsed = torch._C.parse_ir(ir)
    parsed.makeMultiOutputIntoTuple()

    def default_for(value_type):
        # Checked in the order float, int, tensor.
        if isinstance(value_type, torch._C.FloatType):
            return .5
        if isinstance(value_type, torch._C.IntType):
            return 2
        if isinstance(value_type, torch._C.TensorType):
            return make_tensor_from_type(value_type)
        raise NotImplementedError(f"A default value is not implemented for type {value_type}")

    inputs = [default_for(graph_input.type()) for graph_input in parsed.inputs()]

    func = torch._C._create_function_from_graph("forward", parsed)
    # Run on func.graph, i.e. after the function has been created.
    torch._C._jit_pass_erase_shape_information(func.graph)
    return (func, inputs)
# TODO add support for timing on CPU
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Time repeated calls of the function built from ``ir`` using CUDA events.

    Args:
        ir: IR text passed to ``load_graph_and_inputs``; only the resulting
            function is used, the generated inputs are discarded.
        inputs: Positional arguments for every call.
        warmup_runs: Number of untimed calls made first.
        test_runs: Number of timed calls.

    Returns:
        ``elapsed_time`` between the two recorded events divided by
        ``test_runs``.
    """
    fn = load_graph_and_inputs(ir)[0]

    for _ in range(warmup_runs):
        fn(*inputs)

    timer_begin = torch.cuda.Event(enable_timing=True)
    timer_finish = torch.cuda.Event(enable_timing=True)

    # Synchronize around each record call so the events bracket only the
    # timed loop.
    torch.cuda.synchronize()
    timer_begin.record()
    torch.cuda.synchronize()

    for _ in range(test_runs):
        fn(*inputs)

    torch.cuda.synchronize()
    timer_finish.record()
    torch.cuda.synchronize()

    total_elapsed = timer_begin.elapsed_time(timer_finish)
    return total_elapsed / test_runs
@contextmanager
def no_fuser(*args, **kwargs):
    """Context manager that switches off four JIT fuser settings for its body.

    The CPU-fuse, GPU-fuse, texpr-fuser and nvfuser flags are read, all set
    to ``False``, and written back to their saved values on exit, including
    when the body raises. Positional and keyword arguments are accepted and
    ignored.
    """
    # (getter, setter) pairs, in the order they are read and written.
    toggles = (
        (torch._C._jit_can_fuse_on_cpu, torch._C._jit_override_can_fuse_on_cpu),
        (torch._C._jit_can_fuse_on_gpu, torch._C._jit_override_can_fuse_on_gpu),
        (torch._C._jit_texpr_fuser_enabled, torch._C._jit_set_texpr_fuser_enabled),
        (torch._C._jit_nvfuser_enabled, torch._C._jit_set_nvfuser_enabled),
    )
    saved = [getter() for getter, _ in toggles]
    for _, setter in toggles:
        setter(False)
    try:
        yield
    finally:
        for (_, setter), previous in zip(toggles, saved):
            setter(previous)
def run_baseline_no_fusion(ir, inputs) -> float:
    """Return the ``run_test`` timing for ``ir`` measured inside ``no_fuser()``."""
    with no_fuser():
        elapsed = run_test(ir, inputs)
        return elapsed
def run_nnc(ir, inputs) -> float:
    """Return the ``run_test`` timing for ``ir`` under ``torch.jit.fuser("fuser1")``."""
    with torch.jit.fuser("fuser1"):
        elapsed = run_test(ir, inputs)
        return elapsed
def run_nvfuser(ir, inputs) -> float:
    """Return the ``run_test`` timing for ``ir`` under ``torch.jit.fuser("fuser2")``."""
    with torch.jit.fuser("fuser2"):
        elapsed = run_test(ir, inputs)
        return elapsed
def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
    """Benchmark each graph with two timing functions and print a comparison.

    Args:
        graphs: IR strings to benchmark.
        baseline_fn: Called as ``baseline_fn(ir, inputs)``; returns a time.
        nvfuser_fn: Called as ``nvfuser_fn(ir, inputs)``; returns a time.

    For every graph one line is printed with both timings and the percentage
    by which the baseline time exceeds the nvfuser time.
    """
    for index, graph_ir in enumerate(graphs):
        # Only the generated inputs are needed here; each timing function
        # receives the IR text itself.
        generated_inputs = load_graph_and_inputs(graph_ir)[1]
        baseline_time = baseline_fn(graph_ir, generated_inputs)
        nvfuser_time = nvfuser_fn(graph_ir, generated_inputs)
        gain_pct = (baseline_time / nvfuser_time - 1) * 100
        print(f" Graph {index}; baseline: {baseline_time:.2f} ms; nvfuser: {nvfuser_time:.2f} ms; improvement: {gain_pct:.2f}%")
def run():
    """Command-line entry point.

    Parses the arguments, extracts the IR graphs from the given log file,
    then runs whichever benchmarks were requested and/or prints the graphs
    as a bracketed list of triple-quoted strings.
    """
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")

    # Each feature has an enabling and a disabling flag sharing one dest,
    # and defaults to off: (dest, on flag, on help, off flag, off help).
    switches = (
        ("nvfuser",
         "--nvfuser", "benchmark nvfuser against no fusion",
         "--no-nvfuser", "DON'T benchmark nvfuser against no fusion"),
        ("nvfuser_nnc",
         "--nvfuser-nnc", "benchmark nvfuser against nnc",
         "--no-nvfuser-nnc", "DON'T benchmark nvfuser against nnc"),
        ("output",
         "--output", "Output graph IR",
         "--no-output", "DON'T output graph IR"),
    )
    for dest, on_flag, on_help, off_flag, off_help in switches:
        parser.add_argument(on_flag, dest=dest, action="store_true", help=on_help)
        parser.add_argument(off_flag, dest=dest, action="store_false", help=off_help)
        parser.set_defaults(**{dest: False})

    options = parser.parse_args()
    extracted = extract_ir(options.filename)

    # (requested?, heading, baseline timing function)
    benchmarks = (
        (options.nvfuser, "NVFuser vs no fusion:", run_baseline_no_fusion),
        (options.nvfuser_nnc, "NVFuser vs NNC:", run_nnc),
    )
    for requested, heading, baseline_fn in benchmarks:
        if requested:
            print(heading)
            test_nvfuser(extracted, baseline_fn, run_nvfuser)

    if options.output:
        print("[" + ", ".join(f'"""{ir}"""' for ir in extracted) + "]")
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    run()