Extend Graph Export to NNC, extend script to support CPU (#74076)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74076

Extends the repro script to CPU and NNC. As described in the file:

Usage:
```
1. Run your script and pipe into a log file

  PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python3 my_test.py &> log.txt

2. Run log_extract:

  log_extract.py log.txt --baseline --nnc
```

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D34946883

Pulled By: eellison

fbshipit-source-id: 644012dbbca0b490820ef83e761c06b0dd009e52
(cherry picked from commit 5256c8f3ff8545033d1335cc96d34194abda1370)
This commit is contained in:
parent 9c4a63787b
commit c90be037b4
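For context, a minimal `my_test.py` matching the usage in the summary might look like the sketch below (the script name comes from the usage example; the function `f` is hypothetical, and any scripted elementwise workload that triggers the TensorExpr fuser would do):

```python
import torch

# A small elementwise function; scripting it lets the TensorExpr
# fuser form fusion groups that appear in the debug log.
@torch.jit.script
def f(a, b):
    return (a * b + b).relu()

a = torch.rand(1024)
b = torch.rand(1024)

# Run several times so the profiling executor specializes and fuses
# before the interesting graphs are logged.
for _ in range(10):
    f(a, b)
```

Running it as `PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python3 my_test.py &> log.txt` then makes the IR available to `log_extract.py`.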
```diff
@@ -1,10 +1,11 @@
 from contextlib import contextmanager
 from torch.testing import make_tensor
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Callable
 import argparse
 import random
 import torch
 import traceback
+import time
 
 '''
 Usage:
```
```diff
@@ -66,6 +67,26 @@ def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
     torch._C._jit_pass_erase_shape_information(func.graph)
     return (func, inputs)
 
+
+def time_cuda(fn, inputs, test_runs):
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    torch.cuda.synchronize()
+    start_event.record()
+    torch.cuda.synchronize()
+    for i in range(test_runs):
+        fn(*inputs)
+        torch.cuda.synchronize()
+    end_event.record()
+    torch.cuda.synchronize()
+    return start_event.elapsed_time(end_event) / test_runs
+
+def time_cpu(fn, inputs, test_runs):
+    s = time.perf_counter()
+    for _ in range(test_runs):
+        fn(*inputs)
+    e = time.perf_counter()
+    return (e - s) / test_runs
+
 
 # TODO add support for timing on CPU
 def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
```
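For readers unfamiliar with the event-based timing the patch adds, here is a self-contained sketch of the same pattern (the `torch.mm` workload is only an illustration, not part of the patch):

```python
import torch

def time_cuda_sketch(fn, inputs, test_runs):
    # CUDA kernels launch asynchronously, so wall-clock time around
    # fn(*inputs) would largely measure launch overhead. CUDA events
    # are recorded on the GPU stream itself and timed there.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(test_runs):
        fn(*inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / test_runs  # milliseconds per run

if torch.cuda.is_available():
    x = torch.rand(1024, 1024, device="cuda")
    print(time_cuda_sketch(torch.mm, (x, x), test_runs=20))
```

CPU timing, by contrast, can rely on `time.perf_counter()` because CPU ops run synchronously, which is exactly the split `time_cpu`/`time_cuda` encodes. (Note the units: `perf_counter` differences are in seconds, while `Event.elapsed_time` is in milliseconds.)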
```diff
@@ -73,18 +94,15 @@ def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
     for _ in range(warmup_runs):
         graph(*inputs)
 
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    torch.cuda.synchronize()
-    start_event.record()
-    torch.cuda.synchronize()
-    for i in range(test_runs):
-        graph(*inputs)
-        torch.cuda.synchronize()
-    end_event.record()
-    torch.cuda.synchronize()
-    return start_event.elapsed_time(end_event) / test_runs
+    is_cpu = None
+    for input in inputs:
+        if isinstance(input, torch.Tensor):
+            is_cpu = input.device.type == "cpu"
+            break
+    assert is_cpu != None
 
+    out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
+    return out
 
 @contextmanager
 def no_fuser(*args, **kwargs):
```
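The dispatch above infers the device from the first tensor argument; a minimal sketch of that heuristic in isolation (`first_tensor_device` is an illustrative name, not from the patch):

```python
import torch

def first_tensor_device(inputs):
    # Mirrors the patch's heuristic: the first Tensor argument decides
    # between the CPU and CUDA timers.
    for inp in inputs:
        if isinstance(inp, torch.Tensor):
            return inp.device.type
    raise ValueError("no Tensor inputs; cannot infer a device")

print(first_tensor_device([3, torch.rand(2)]))  # "cpu"
```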
```diff
@@ -122,16 +140,25 @@ def run_nvfuser(ir, inputs) -> float:
     return run_test(ir, inputs)
 
 
-def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
+def test_runners(graphs: List[str], runners: List[Tuple[str, Callable]]):
     for i, ir in enumerate(graphs):
         _, inputs = load_graph_and_inputs(ir)
-        try:
-            baseline = baseline_fn(ir, inputs)
-            nvfuser = nvfuser_fn(ir, inputs)
-            improvement = (baseline / nvfuser - 1) * 100
-            print(f"  Graph {i}; baseline: {baseline:.2f} ms; nvfuser: {nvfuser:.2f} ms; improvement: {improvement:.2f}%")
-        except RuntimeError:
-            print(f"  Graph {i} failed:", traceback.format_exc())
+        print(f"Running Graph {ir}")
+        prev_result = None
+        prev_runner_name = None
+        for runner in runners:
+            runner_name, runner_fn = runner
+            try:
+                result = runner_fn(ir, inputs)
+                if prev_result:
+                    improvement = (prev_result / result - 1) * 100
+                    print(f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%")
+                else:
+                    print(f"{runner_name} : {result:.6f} ms")
+                prev_result = result
+                prev_runner_name = runner_name
+            except RuntimeError:
+                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())
 
 
 def run():
```
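The `(name, callable)` runner protocol replaces the fixed baseline-vs-nvfuser comparison with a chain where each runner is reported relative to the previous one. A toy sketch of that comparison logic (the fixed latencies are invented for illustration):

```python
from typing import Callable, List, Tuple

def compare(runners: List[Tuple[str, Callable[[], float]]]) -> None:
    prev_result, prev_name = None, None
    for name, fn in runners:
        result = fn()
        if prev_result:
            # Same relative-improvement formula as the patch.
            improvement = (prev_result / result - 1) * 100
            print(f"{name}: {result:.6f} ms, {improvement:.2f}% over {prev_name}")
        else:
            print(f"{name}: {result:.6f} ms")
        prev_result, prev_name = result, name

# Invented latencies standing in for run_baseline_no_fusion/run_nnc/run_nvfuser.
compare([
    ("Baseline no fusion", lambda: 2.0),
    ("NNC", lambda: 1.6),
    ("NVFuser", lambda: 1.2),
])
```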
```diff
@@ -139,12 +166,17 @@ def run():
         description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
     )
     parser.add_argument("filename", help="Filename of log file")
-    parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser against no fusion")
-    parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser against no fusion")
+    parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser")
+    parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser")
     parser.set_defaults(nvfuser=False)
-    parser.add_argument("--nvfuser-nnc", dest="nvfuser_nnc", action="store_true", help="benchmark nvfuser against nnc")
-    parser.add_argument("--no-nvfuser-nnc", dest="nvfuser_nnc", action="store_false", help="DON'T benchmark nvfuser against nnc")
-    parser.set_defaults(nvfuser_nnc=False)
+    parser.add_argument("--nnc", dest="nnc", action="store_true", help="benchmark nnc")
+    parser.add_argument("--no-nnc", dest="nnc", action="store_false", help="DON'T benchmark nnc")
+    parser.set_defaults(nnc=False)
+
+    parser.add_argument("--baseline", dest="baseline", action="store_true", help="benchmark baseline")
+    parser.add_argument("--no-baseline", dest="baseline", action="store_false", help="DON'T benchmark baseline")
+    parser.set_defaults(baseline=False)
 
     parser.add_argument("--output", dest="output", action="store_true", help="Output graph IR")
     parser.add_argument("--no-output", dest="output", action="store_false", help="DON'T output graph IR")
     parser.set_defaults(output=False)
```
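The paired `--X`/`--no-X` switches use a standard argparse idiom: two actions writing to the same `dest`, with `set_defaults` pinning the value when neither flag is passed. A minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser()
# store_true/store_false pairs on one dest give an explicit
# on/off switch with a fixed default.
parser.add_argument("--nnc", dest="nnc", action="store_true")
parser.add_argument("--no-nnc", dest="nnc", action="store_false")
parser.set_defaults(nnc=False)

print(parser.parse_args(["--nnc"]).nnc)     # True
print(parser.parse_args(["--no-nnc"]).nnc)  # False
print(parser.parse_args([]).nnc)            # False (default)
```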
```diff
@@ -152,13 +184,15 @@ def run():
     args = parser.parse_args()
     graphs = extract_ir(args.filename)
 
+    options = []
+    if args.baseline:
+        options.append(("Baseline no fusion", run_baseline_no_fusion))
+    if args.nnc:
+        options.append(("NNC", run_nnc))
     if args.nvfuser:
-        print("NVFuser vs no fusion:")
-        test_nvfuser(graphs, run_baseline_no_fusion, run_nvfuser)
+        options.append(("NVFuser", run_nvfuser))
 
-    if args.nvfuser_nnc:
-        print("NVFuser vs NNC:")
-        test_nvfuser(graphs, run_nnc, run_nvfuser)
+    test_runners(graphs, options)
 
     if args.output:
         quoted = []
```
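Because the options list is built in baseline, NNC, NVFuser order, running with `--baseline --nnc --nvfuser` makes `test_runners` report NNC relative to the baseline and NVFuser relative to NNC.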
```diff
@@ -745,6 +745,10 @@ class TensorExprFuser {
     }
     // Cleanup the subgraph from duplicated constants while we're at it.
    ConstantPooling(subgraph);
+
+    if (GRAPH_DEBUG_ENABLED) {
+      GRAPH_EXPORT("", subgraph);
+    }
     return false;
   }
```
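In the fuser pass itself, the new `GRAPH_EXPORT` call emits the fused subgraph's IR into the debug log when `GRAPH_DEBUG_ENABLED` is set; this is presumably the serialized form that `log_extract.py` later parses back out of `log.txt`.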