mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[JIT] script & logging for extracting IR from logs (#72889)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72889 The script along with the GRAPH_EXPORT macro will allow for an easy way to extract IR from logs. One use case in this diff is to extract the fusion groups from nvfuser, so that the fusions can be tested individually. Usage (e.g. for nvfuser test) 1. Write some test.py file that uses nvfuser 2. `PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 test.py 2>&1 | tee output.txt` 3. `python3 pytorch/scripts/jit/log_extract.py output.txt --nvfuser` This will run with and without nvfuser to compare the output. Alternatively, use `--output` to dump the IR so that it can be used in other applications. Currently, only `--output` works (since generating input tensors is not supported) Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D34440189 Pulled By: davidberard98 fbshipit-source-id: fca0f619200ee37aba34bb39b69e6c640c263e26 (cherry picked from commit eb319166075db160f1628f0de545641fbecde8be)
This commit is contained in:
parent
b7a7cdd00a
commit
b27ec57331
165
scripts/jit/log_extract.py
Normal file
165
scripts/jit/log_extract.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
from contextlib import contextmanager
|
||||
from torch.testing import make_tensor
|
||||
from typing import Any, List, Tuple
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
'''
|
||||
Usage:
|
||||
1. Run your script and pipe into a log file
|
||||
PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
|
||||
2. Run log_extract:
|
||||
log_extract.py log.txt --nvfuser
|
||||
|
||||
You can also extract the list of extracted IR:
|
||||
log_extract.py log.txt --output
|
||||
'''
|
||||
|
||||
def extract_ir(filename: str) -> List[str]:
|
||||
BEGIN = "<GRAPH_EXPORT>"
|
||||
END = "</GRAPH_EXPORT>"
|
||||
pfx = None
|
||||
current = ""
|
||||
graphs = []
|
||||
with open(filename, "r") as f:
|
||||
split_strs = f.read().split(BEGIN)
|
||||
for i, split_str in enumerate(split_strs):
|
||||
if i == 0:
|
||||
continue
|
||||
end_loc = split_str.find(END)
|
||||
if end_loc == -1:
|
||||
continue
|
||||
s = split_str[:end_loc]
|
||||
pfx = split_strs[i - 1].splitlines()[-1]
|
||||
lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
|
||||
graphs.append(''.join(lines))
|
||||
|
||||
return graphs
|
||||
|
||||
|
||||
def make_tensor_from_type(inp_type: torch._C.TensorType):
|
||||
if inp_type.requires_grad() is not False:
|
||||
raise NotImplementedError("Tensors with requires_grad are not implemented")
|
||||
return make_tensor(
|
||||
inp_type.sizes(),
|
||||
dtype=inp_type.dtype(),
|
||||
device=inp_type.device())
|
||||
|
||||
|
||||
def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
|
||||
graph = torch._C.parse_ir(ir)
|
||||
graph.makeMultiOutputIntoTuple()
|
||||
inputs = []
|
||||
for inp in graph.inputs():
|
||||
if isinstance(inp.type(), torch._C.FloatType):
|
||||
inputs.append(.5)
|
||||
elif isinstance(inp.type(), torch._C.IntType):
|
||||
inputs.append(2)
|
||||
elif isinstance(inp.type(), torch._C.TensorType):
|
||||
inputs.append(make_tensor_from_type(inp.type()))
|
||||
else:
|
||||
raise NotImplementedError(f"A default value is not implemented for type {inp.type()}")
|
||||
|
||||
func = torch._C._create_function_from_graph("forward", graph)
|
||||
torch._C._jit_pass_erase_shape_information(func.graph)
|
||||
return (func, inputs)
|
||||
|
||||
|
||||
# TODO add support for timing on CPU
|
||||
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
|
||||
graph, _ = load_graph_and_inputs(ir)
|
||||
for _ in range(warmup_runs):
|
||||
graph(*inputs)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
torch.cuda.synchronize()
|
||||
for i in range(test_runs):
|
||||
graph(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
return start_event.elapsed_time(end_event) / test_runs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_fuser(*args, **kwargs):
|
||||
old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
|
||||
old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
|
||||
old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
|
||||
old_nvfuser_state = torch._C._jit_nvfuser_enabled()
|
||||
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
|
||||
torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
|
||||
torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
|
||||
torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
|
||||
|
||||
|
||||
def run_baseline_no_fusion(ir, inputs) -> float:
|
||||
with no_fuser():
|
||||
return run_test(ir, inputs)
|
||||
|
||||
|
||||
def run_nnc(ir, inputs) -> float:
|
||||
with torch.jit.fuser("fuser1"):
|
||||
return run_test(ir, inputs)
|
||||
|
||||
|
||||
def run_nvfuser(ir, inputs) -> float:
|
||||
with torch.jit.fuser("fuser2"):
|
||||
return run_test(ir, inputs)
|
||||
|
||||
|
||||
def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
|
||||
for i, ir in enumerate(graphs):
|
||||
_, inputs = load_graph_and_inputs(ir)
|
||||
baseline = baseline_fn(ir, inputs)
|
||||
nvfuser = nvfuser_fn(ir, inputs)
|
||||
improvement = (baseline / nvfuser - 1) * 100
|
||||
print(f" Graph {i}; baseline: {baseline:.2f} ms; nvfuser: {nvfuser:.2f} ms; improvement: {improvement:.2f}%")
|
||||
|
||||
|
||||
def run():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
|
||||
)
|
||||
parser.add_argument("filename", help="Filename of log file")
|
||||
parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser against no fusion")
|
||||
parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser against no fusion")
|
||||
parser.set_defaults(nvfuser=False)
|
||||
parser.add_argument("--nvfuser-nnc", dest="nvfuser_nnc", action="store_true", help="benchmark nvfuser against nnc")
|
||||
parser.add_argument("--no-nvfuser-nnc", dest="nvfuser_nnc", action="store_false", help="DON'T benchmark nvfuser against nnc")
|
||||
parser.set_defaults(nvfuser_nnc=False)
|
||||
parser.add_argument("--output", dest="output", action="store_true", help="Output graph IR")
|
||||
parser.add_argument("--no-output", dest="output", action="store_false", help="DON'T output graph IR")
|
||||
parser.set_defaults(output=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
graphs = extract_ir(args.filename)
|
||||
|
||||
if args.nvfuser:
|
||||
print("NVFuser vs no fusion:")
|
||||
test_nvfuser(graphs, run_baseline_no_fusion, run_nvfuser)
|
||||
|
||||
if args.nvfuser_nnc:
|
||||
print("NVFuser vs NNC:")
|
||||
test_nvfuser(graphs, run_nnc, run_nvfuser)
|
||||
|
||||
if args.output:
|
||||
quoted = []
|
||||
for ir in graphs:
|
||||
quoted.append("\"\"\"" + ir + "\"\"\"")
|
||||
print("[" + ", ".join(quoted) + "]")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
|
@ -19,6 +19,7 @@
|
|||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
#include <torch/csrc/jit/runtime/autodiff.h>
|
||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
|
|
@ -1709,6 +1710,13 @@ void guardFusionGroups(
|
|||
// c. restore conditional constant to non-constant for fallback
|
||||
guardFusionGroup(fusion, fusion_value_to_runtime_size);
|
||||
}
|
||||
|
||||
if (GRAPH_DEBUG_ENABLED) {
|
||||
GRAPH_DEBUG("Exporting all NVFuser fusions:");
|
||||
for (Node* fusion : fusions) {
|
||||
GRAPH_EXPORT("", fusion->g(attr::Subgraph));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewire const integer index & empty byte-typed reserve space tensor outputs,
|
||||
|
|
|
|||
|
|
@ -109,6 +109,14 @@ TORCH_API std::ostream& operator<<(
|
|||
// pass
|
||||
#define GRAPH_DEBUG(...) \
|
||||
JIT_LOG(::torch::jit::JitLoggingLevels::GRAPH_DEBUG, __VA_ARGS__);
|
||||
// use GRAPH_EXPORT to export a graph so that the IR can be loaded by a script
|
||||
#define GRAPH_EXPORT(MSG, G) \
|
||||
JIT_LOG( \
|
||||
::torch::jit::JitLoggingLevels::GRAPH_DEBUG, \
|
||||
MSG, \
|
||||
"\n<GRAPH_EXPORT>\n", \
|
||||
(G)->toString(), \
|
||||
"</GRAPH_EXPORT>");
|
||||
|
||||
#define GRAPH_DUMP_ENABLED \
|
||||
(is_enabled(__FILE__, ::torch::jit::JitLoggingLevels::GRAPH_DUMP))
|
||||
|
|
|
|||
|
|
@ -934,6 +934,9 @@ void initPythonIRBindings(PyObject* module_) {
|
|||
[](const TypePtr& self) {
|
||||
return self->castRaw<InterfaceType>() != nullptr;
|
||||
})
|
||||
.def(
|
||||
"requires_grad",
|
||||
[](const TypePtr& self) -> bool { return self->requires_grad(); })
|
||||
.def_property_readonly(
|
||||
"annotation_str", [](const std::shared_ptr<Type>& self) {
|
||||
return self->annotation_str();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user