Extend Graph Export to NNC, extend script to support CPU (#74076)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74076

Extends the repro script to support CPU and NNC. Usage, as documented in the file:
```
1. Run your script and pipe into a log file
  PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python3 my_test.py &> log.txt
2. Run log_extract:
  log_extract.py log.txt --baseline --nnc
```
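For context on what the script does with each extracted graph: it parses the IR string back into a callable and benchmarks it. A minimal sketch of that round trip, assuming the public `torch._C.parse_ir` API and the private `torch._C._create_function_from_graph` helper; the trivial relu graph here is illustrative, not from the PR:

```
# Minimal sketch: turn a TorchScript IR string (as extracted from the log)
# back into a callable and run it. The relu graph below is a made-up example.
import torch

ir = """
graph(%a : Tensor):
  %b : Tensor = aten::relu(%a)
  return (%b)
"""

# Parse the textual IR into a Graph, then wrap it in a callable function.
graph = torch._C.parse_ir(ir)
func = torch._C._create_function_from_graph("forward", graph)
print(func(torch.randn(2, 2)))
```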

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D34946883

Pulled By: eellison

fbshipit-source-id: 644012dbbca0b490820ef83e761c06b0dd009e52
(cherry picked from commit 5256c8f3ff8545033d1335cc96d34194abda1370)
Author: Elias Ellison
Date: 2022-03-29 11:32:31 -07:00
Committed by: PyTorch MergeBot
Parent: 9c4a63787b
Commit: c90be037b4

2 changed files with 68 additions and 30 deletions

log_extract.py

```
@@ -1,10 +1,11 @@
 from contextlib import contextmanager
 from torch.testing import make_tensor
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Callable
 import argparse
 import random
 import torch
 import traceback
+import time

 '''
 Usage:
@@ -66,6 +67,26 @@ def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
     torch._C._jit_pass_erase_shape_information(func.graph)
     return (func, inputs)

+def time_cuda(fn, inputs, test_runs):
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    torch.cuda.synchronize()
+    start_event.record()
+    torch.cuda.synchronize()
+    for i in range(test_runs):
+        fn(*inputs)
+        torch.cuda.synchronize()
+    end_event.record()
+    torch.cuda.synchronize()
+    return start_event.elapsed_time(end_event) / test_runs
+
+def time_cpu(fn, inputs, test_runs):
+    s = time.perf_counter()
+    for _ in range(test_runs):
+        fn(*inputs)
+    e = time.perf_counter()
+    return (e - s) / test_runs
+
 # TODO add support for timing on CPU
 def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
@@ -73,18 +94,15 @@ def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
     for _ in range(warmup_runs):
         graph(*inputs)
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    torch.cuda.synchronize()
-    start_event.record()
-    torch.cuda.synchronize()
-    for i in range(test_runs):
-        graph(*inputs)
-        torch.cuda.synchronize()
-    end_event.record()
-    torch.cuda.synchronize()
-    return start_event.elapsed_time(end_event) / test_runs
+    is_cpu = None
+    for input in inputs:
+        if isinstance(input, torch.Tensor):
+            is_cpu = input.device.type == "cpu"
+            break
+    assert is_cpu != None
+
+    out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
+    return out

 @contextmanager
 def no_fuser(*args, **kwargs):
@@ -122,16 +140,25 @@ def run_nvfuser(ir, inputs) -> float:
     return run_test(ir, inputs)

-def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
+def test_runners(graphs: List[str], runners: List[Tuple[str, Callable]]):
     for i, ir in enumerate(graphs):
         _, inputs = load_graph_and_inputs(ir)
-        try:
-            baseline = baseline_fn(ir, inputs)
-            nvfuser = nvfuser_fn(ir, inputs)
-            improvement = (baseline / nvfuser - 1) * 100
-            print(f"    Graph {i}; baseline: {baseline:.2f} ms; nvfuser: {nvfuser:.2f} ms; improvement: {improvement:.2f}%")
-        except RuntimeError:
-            print(f"    Graph {i} failed:", traceback.format_exc())
+        print(f"Running Graph {ir}")
+        prev_result = None
+        prev_runner_name = None
+        for runner in runners:
+            runner_name, runner_fn = runner
+            try:
+                result = runner_fn(ir, inputs)
+                if prev_result:
+                    improvement = (prev_result / result - 1) * 100
+                    print(f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%")
+                else:
+                    print(f"{runner_name} : {result:.6f} ms")
+                prev_result = result
+                prev_runner_name = runner_name
+            except RuntimeError:
+                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())

 def run():
@@ -139,12 +166,17 @@ def run():
         description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
     )
     parser.add_argument("filename", help="Filename of log file")
-    parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser against no fusion")
-    parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser against no fusion")
+    parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser")
+    parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser")
     parser.set_defaults(nvfuser=False)
-    parser.add_argument("--nvfuser-nnc", dest="nvfuser_nnc", action="store_true", help="benchmark nvfuser against nnc")
-    parser.add_argument("--no-nvfuser-nnc", dest="nvfuser_nnc", action="store_false", help="DON'T benchmark nvfuser against nnc")
-    parser.set_defaults(nvfuser_nnc=False)
+    parser.add_argument("--nnc", dest="nnc", action="store_true", help="benchmark nnc")
+    parser.add_argument("--no-nnc", dest="nnc", action="store_false", help="DON'T benchmark nnc")
+    parser.set_defaults(nnc=False)
+    parser.add_argument("--baseline", dest="baseline", action="store_true", help="benchmark baseline")
+    parser.add_argument("--no-baseline", dest="baseline", action="store_false", help="DON'T benchmark baseline")
+    parser.set_defaults(baseline=False)
     parser.add_argument("--output", dest="output", action="store_true", help="Output graph IR")
     parser.add_argument("--no-output", dest="output", action="store_false", help="DON'T output graph IR")
     parser.set_defaults(output=False)
@@ -152,13 +184,15 @@ def run():
     args = parser.parse_args()
     graphs = extract_ir(args.filename)

+    options = []
+    if args.baseline:
+        options.append(("Baseline no fusion", run_baseline_no_fusion))
+    if args.nnc:
+        options.append(("NNC", run_nnc))
     if args.nvfuser:
-        print("NVFuser vs no fusion:")
-        test_nvfuser(graphs, run_baseline_no_fusion, run_nvfuser)
+        options.append(("NVFuser", run_nvfuser))

-    if args.nvfuser_nnc:
-        print("NVFuser vs NNC:")
-        test_nvfuser(graphs, run_nnc, run_nvfuser)
+    test_runners(graphs, options)

     if args.output:
         quoted = []
```
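Since baseline, NNC, and NVFuser are now independent flags collected into an ordered runner list, each selected backend is timed in turn and compared against the one before it. An illustrative invocation and output (the format follows the print calls in `test_runners` above; the timings themselves are made up):

```
$ python3 log_extract.py log.txt --baseline --nnc --nvfuser
Running Graph graph(%a : Tensor): ...
Baseline no fusion : 1.523000 ms
NNC : 1.201000 ms improvement over Baseline no fusion: improvement: 26.81%
NVFuser : 1.150000 ms improvement over NNC: improvement: 4.43%
```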

torch/csrc/jit/passes/tensorexpr_fuser.cpp

```
@@ -745,6 +745,10 @@ class TensorExprFuser {
     }
     // Cleanup the subgraph from duplicated constants while we're at it.
     ConstantPooling(subgraph);
+
+    if (GRAPH_DEBUG_ENABLED) {
+      GRAPH_EXPORT("", subgraph);
+    }
     return false;
   }
```
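The `GRAPH_EXPORT` call above writes the fused subgraph's IR into the debug log when `>>`-level logging is enabled for this file, and `extract_ir` on the Python side then scans the log for it. A minimal sketch of such a scan, assuming the IR is delimited by `<GRAPH_EXPORT>`/`</GRAPH_EXPORT>` marker lines (marker names inferred from the jit_log conventions; the script's real `extract_ir` is authoritative):

```
# Minimal sketch of pulling GRAPH_EXPORT-delimited IR blocks out of a debug
# log. Assumes <GRAPH_EXPORT>/</GRAPH_EXPORT> marker lines around each graph.
from typing import List

def extract_graphs(log_text: str) -> List[str]:
    graphs: List[str] = []
    current: List[str] = []
    inside = False
    for line in log_text.splitlines():
        if "<GRAPH_EXPORT>" in line:
            # Start of a new exported graph.
            inside, current = True, []
        elif "</GRAPH_EXPORT>" in line:
            # End of the block: collect the accumulated IR lines.
            inside = False
            graphs.append("\n".join(current))
        elif inside:
            current.append(line)
    return graphs
```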