mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT] log extract tool - dump NVFuser fallbacks instead of fusion groups (#73881)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73881. NVFuser fusion groups can contain nvfuser-only ops, e.g. `prim::reshape_copy`. Previously, we couldn't get a baseline performance measurement because the nvfuser-only ops would error out on NNC and no-fusion runs. Instead, dump the fallback graphs after they have been corrected into runnable fallbacks. Test Plan: Imported from OSS. Reviewed By: eellison. Differential Revision: D34698307. Pulled By: davidberard98. fbshipit-source-id: c357b2736b789bfd347afe9c83a1b610b64881e0 (cherry picked from commit 5918d826502ff75fbc22d242844ae6435dd7d22a)
This commit is contained in:
parent
56164c07c4
commit
31b64fc3e6
|
|
@ -2,7 +2,9 @@ from contextlib import contextmanager
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
import argparse
|
import argparse
|
||||||
|
import random
|
||||||
import torch
|
import torch
|
||||||
|
import traceback
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Usage:
|
Usage:
|
||||||
|
|
@ -52,9 +54,9 @@ def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
|
||||||
inputs = []
|
inputs = []
|
||||||
for inp in graph.inputs():
|
for inp in graph.inputs():
|
||||||
if isinstance(inp.type(), torch._C.FloatType):
|
if isinstance(inp.type(), torch._C.FloatType):
|
||||||
inputs.append(.5)
|
inputs.append(random.uniform(.1, 100))
|
||||||
elif isinstance(inp.type(), torch._C.IntType):
|
elif isinstance(inp.type(), torch._C.IntType):
|
||||||
inputs.append(2)
|
inputs.append(random.randint(1, 100))
|
||||||
elif isinstance(inp.type(), torch._C.TensorType):
|
elif isinstance(inp.type(), torch._C.TensorType):
|
||||||
inputs.append(make_tensor_from_type(inp.type()))
|
inputs.append(make_tensor_from_type(inp.type()))
|
||||||
else:
|
else:
|
||||||
|
|
@ -123,10 +125,13 @@ def run_nvfuser(ir, inputs) -> float:
|
||||||
def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
|
def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
|
||||||
for i, ir in enumerate(graphs):
|
for i, ir in enumerate(graphs):
|
||||||
_, inputs = load_graph_and_inputs(ir)
|
_, inputs = load_graph_and_inputs(ir)
|
||||||
baseline = baseline_fn(ir, inputs)
|
try:
|
||||||
nvfuser = nvfuser_fn(ir, inputs)
|
baseline = baseline_fn(ir, inputs)
|
||||||
improvement = (baseline / nvfuser - 1) * 100
|
nvfuser = nvfuser_fn(ir, inputs)
|
||||||
print(f" Graph {i}; baseline: {baseline:.2f} ms; nvfuser: {nvfuser:.2f} ms; improvement: {improvement:.2f}%")
|
improvement = (baseline / nvfuser - 1) * 100
|
||||||
|
print(f" Graph {i}; baseline: {baseline:.2f} ms; nvfuser: {nvfuser:.2f} ms; improvement: {improvement:.2f}%")
|
||||||
|
except RuntimeError:
|
||||||
|
print(f" Graph {i} failed:", traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -1710,11 +1710,15 @@ void guardFusionGroups(
|
||||||
// c. restore conditional constant to non-constant for fallback
|
// c. restore conditional constant to non-constant for fallback
|
||||||
guardFusionGroup(fusion, fusion_value_to_runtime_size);
|
guardFusionGroup(fusion, fusion_value_to_runtime_size);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (GRAPH_DEBUG_ENABLED) {
|
void dumpFusionGroups(std::shared_ptr<Graph>& g) {
|
||||||
GRAPH_DEBUG("Exporting all NVFuser fusions:");
|
DepthFirstGraphNodeIterator it(g);
|
||||||
for (Node* fusion : fusions) {
|
Node* n = nullptr;
|
||||||
GRAPH_EXPORT("", fusion->g(attr::Subgraph));
|
GRAPH_DEBUG("Exporting all NVFuser fusions:");
|
||||||
|
while ((n = it.next()) != nullptr) {
|
||||||
|
if (n->kind() == prim::FallbackGraph) {
|
||||||
|
GRAPH_EXPORT("", n->g(attr::Subgraph));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2305,6 +2309,8 @@ void CudaFuseGraph(std::shared_ptr<Graph>& graph) {
|
||||||
revertAliasCopyOps(graph, graph->block());
|
revertAliasCopyOps(graph, graph->block());
|
||||||
GRAPH_DEBUG("revert alias_copy ops by nvfuser: ", *graph);
|
GRAPH_DEBUG("revert alias_copy ops by nvfuser: ", *graph);
|
||||||
|
|
||||||
|
dumpFusionGroups(graph);
|
||||||
|
|
||||||
// After FuseGraph some common subexpressions may come back
|
// After FuseGraph some common subexpressions may come back
|
||||||
EliminateCommonSubexpression(graph);
|
EliminateCommonSubexpression(graph);
|
||||||
// We might have emitted a fair amount of useless shape propagating code, so
|
// We might have emitted a fair amount of useless shape propagating code, so
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user