Add extra metadata (as comments) to Inductor generated code (#96581)

New output
[Screenshot: Inductor-generated Triton code with the new metadata comments]
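For readers of the plain-text log, the header now attached to each generated kernel looks roughly like this (the path, kernel name, and node names below are invented for illustration, not taken from the screenshot):

```python
# kernel path: /tmp/torchinductor_user/ab/cab1example.py
# Original ATen: aten.cos, aten.mm
#
# aten.cos => cos, cos_1
# aten.mm => mm
triton_fused_cos_cos_mm_0 = async_compile.triton('''
...
''')
```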

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96581
Approved by: https://github.com/ngimel, https://github.com/shunting314, https://github.com/voznesenskym
commit 2a08a62777 (parent f56cb41c2e)
Author: Horace He
Date: 2023-03-13 21:26:52 +00:00 (committed by PyTorch MergeBot)

7 changed files with 72 additions and 19 deletions


@@ -7505,6 +7505,23 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                     func_and_kernel_torch
                 )
 
+        @patch.object(config, "profile_bandwidth", True)
+        def test_bandwidth_profiler(self):
+            @torch._dynamo.optimize("inductor")
+            def fn(x):
+                x = x.cos()
+                x = x.cos()
+                x = torch.mm(x, x)
+                x = x.sin()
+                x = x.relu()
+                return x
+
+            inp = torch.randn(4, 4, device="cuda")
+            code = run_and_get_triton_code(fn, inp)
+            fn(inp)
+            self.assertTrue("start_graph" in code)
+            self.assertTrue("end_graph" in code)
+
        def test_split_op_with_sym(self):
            def fn(x: torch.Tensor) -> torch.Tensor:
                # split(tensor, sympy.Integer), split(tensor, sympy.Expr)

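For context, a minimal sketch of driving the new profiler outside the test harness. It assumes a CUDA build and uses the same `torch._dynamo.optimize("inductor")` entry point as the test; the exact printed table is whatever `start_graph`/`end_graph` emit.

```python
import torch
import torch._dynamo
import torch._inductor.config as inductor_config

inductor_config.profile_bandwidth = True  # same knob the test patches

@torch._dynamo.optimize("inductor")
def fn(x):
    return torch.mm(x.cos(), x.cos()).sin().relu()

# First call compiles; the instrumented kernels then report per-kernel
# runtime and estimated bandwidth, with a SUMMARY printed by end_graph().
fn(torch.randn(1024, 1024, device="cuda"))
```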

@@ -19,6 +19,7 @@ from ..ir import ReductionHint
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..utils import (
     get_fused_kernel_name,
+    get_kernel_metadata,
     instance_descriptor,
     next_power_of_2,
     sympy_product,
@@ -1673,7 +1674,11 @@ class TritonScheduling:
             compile_wrapper.splice(src_code, strip=True)
             compile_wrapper.writeline("''')")
 
-        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), kernel_path)
+        metadata_comment = f"# kernel path: {kernel_path}"
+        metadata_comment += "\n" + get_kernel_metadata(node_schedule)
+        wrapper.define_kernel(
+            kernel_name, compile_wrapper.getvalue(), metadata_comment
+        )
         return kernel_name
 
     def codegen_template(self, template_node, epilogue_nodes):

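End to end, the effect of this change can be checked with the same helper the test uses (assuming `run_and_get_triton_code` is importable from `torch._inductor.utils`, as in the test file):

```python
import torch
import torch._dynamo
from torch._inductor.utils import run_and_get_triton_code

@torch._dynamo.optimize("inductor")
def fn(x):
    return x.cos().cos()

code = run_and_get_triton_code(fn, torch.randn(4, 4, device="cuda"))
# The generated wrapper should now carry the new comments above each kernel.
assert "# kernel path:" in code
assert "# Original ATen:" in code
```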

@@ -628,9 +628,9 @@ class WrapperCodeGen(CodeGen):
         with output.indent():
             output.writeline("benchmark_compiled_module()")
 
-    def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
-        kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
-        self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")
+    def define_kernel(self, name: str, kernel: str, metadata: str = None):
+        metadata_comment = f"{metadata}\n" if metadata else ""
+        self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
 
     def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
         return

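A toy model of the generalized hook, with the wrapper's header buffer reduced to a `StringIO` (the class and variable names here are simplified stand-ins, not Inductor's real types):

```python
import io

class ToyWrapper:
    def __init__(self):
        self.header = io.StringIO()

    def define_kernel(self, name, kernel, metadata=None):
        # Whatever comment block the caller assembled is emitted verbatim
        # above the kernel assignment, not just a single kernel-path line.
        metadata_comment = f"{metadata}\n" if metadata else ""
        self.header.write(f"\n\n{metadata_comment}{name} = {kernel}")

w = ToyWrapper()
w.define_kernel(
    "triton_fused_cos_0",
    "async_compile.triton('''...''')",
    metadata="# kernel path: /tmp/example.py\n# Original ATen: aten.cos",
)
print(w.header.getvalue())
```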

@@ -212,9 +212,10 @@ class triton:
 
     # should we put op names in kernel names
     # False: No special names (just triton__1, triton__2, etc.)
-    # "torch": Maps to the fx node in the Dynamo graph (module name, method name, etc.)
-    # "aten": Maps to the highest-level aten op (i.e. pre-decompositions)
-    descriptive_names = "aten"
+    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
+    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
+    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
+    descriptive_names = "original_aten"
 
     # use alternate codegen for smaller reductions
     persistent_reductions = True

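Roughly how the same fused cos+mm kernel might be named under each setting (illustrative only; actual names depend on the graph and on kernel numbering):

```python
# descriptive_names = False            -> triton__0
# descriptive_names = "torch"          -> triton_fused_cos_mm_0        (post-dynamo source_fn names)
# descriptive_names = "original_aten"  -> triton_fused_cos_mm_0        (pre-decomposition aten ops)
# descriptive_names = "inductor_node"  -> triton_fused_cos_cos_1_mm_0  (raw FX node names)
```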

@@ -233,8 +233,8 @@ def start_graph():
 def end_graph():
     if len(collected_calls) == 0:
         return
-    overall_time = sum(call[1] for call in collected_calls)
-    overall_gb = sum(call[2] for call in collected_calls)
+    overall_time = sum(call[0] for call in collected_calls)
+    overall_gb = sum(call[1] for call in collected_calls)
     cur_file = inspect.stack()[1].filename
     print(f"SUMMARY ({cur_file})")
     print(

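The index fix implies each entry of `collected_calls` begins with the timing and byte counts. A self-contained sketch of the corrected summation; the `(time_ms, num_gb, gb_per_s, kernel_name)` layout is an assumption made for illustration:

```python
# Assumed per-call layout: (time_ms, num_gb, gb_per_s, kernel_name)
collected_calls = [
    (0.021, 0.004, 190.5, "triton_fused_cos_0"),
    (0.105, 0.512, 4876.2, "triton_fused_mm_1"),
]

overall_time = sum(call[0] for call in collected_calls)  # 0.126 (ms)
overall_gb = sum(call[1] for call in collected_calls)    # 0.516 (GB)
print(f"{overall_gb / (overall_time / 1000):.1f} GB/s overall")
```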

@@ -238,13 +238,14 @@ def get_fused_kernel_name(node_schedule):
         operator.or_,
         [node.node.origins for node in node_schedule if hasattr(node, "node")],
     )
-    if config.triton.descriptive_names == "aten":
+    if config.triton.descriptive_names == "original_aten":
         # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
         sources = [
             origin.meta["original_aten"]._overloadpacket.__name__
             for origin in all_origins
             if origin.op == "call_function" and "original_aten" in origin.meta
         ]
+        sources = sorted(set(sources))
     elif config.triton.descriptive_names == "torch":
         # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
         sources = []
@@ -254,13 +255,39 @@ def get_fused_kernel_name(node_schedule):
                     sources.append(origin.meta["source_fn"])
                 else:
                     sources.append(origin.meta["source_fn"].__name__)
         sources = sorted(set(sources))
+    elif config.triton.descriptive_names == "inductor_node":
+        sources = [
+            origin.name for origin in all_origins if origin.op == "call_function"
+        ]
     else:
         raise NotImplementedError
-    sources = sources[: config.kernel_name_max_ops]
+    sources = set(sources)
+    sources = sorted(sources)[: config.kernel_name_max_ops]
     return "_".join(["fused"] + sources)
 
 
+def get_kernel_metadata(node_schedule):
+    all_origins = functools.reduce(
+        operator.or_,
+        [node.node.origins for node in node_schedule if hasattr(node, "node")],
+    )
+    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]
+    original_aten_dict = collections.defaultdict(list)
+    for node in inductor_nodes:
+        if "original_aten" in node.meta:
+            original_aten_dict[str(node.meta["original_aten"]._overloadpacket)].append(
+                node
+            )
+    metadata = [
+        f"# Original ATen: {', '.join(original_aten_dict.keys())}\n",
+    ]
+    for original_aten, nodes in original_aten_dict.items():
+        metadata.append(
+            f"# {original_aten} => {', '.join([node.name for node in nodes])}"
+        )
+    return "\n".join(metadata)
+
+
 def gather_origins(args, kwargs):
     import itertools
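To see what `get_kernel_metadata` produces without standing up the whole scheduler, the same grouping can be replayed on a hand-traced FX graph, manually attaching the `original_aten` metadata that is normally recorded during lowering:

```python
import collections
import torch
import torch.fx

def fn(x):
    return torch.cos(torch.cos(x))

gm = torch.fx.symbolic_trace(fn)
# Attach by hand the metadata that would normally be recorded upstream.
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["original_aten"] = torch.ops.aten.cos.default

original_aten_dict = collections.defaultdict(list)
for node in gm.graph.nodes:
    if node.op == "call_function" and "original_aten" in node.meta:
        original_aten_dict[str(node.meta["original_aten"]._overloadpacket)].append(node)

metadata = [f"# Original ATen: {', '.join(original_aten_dict.keys())}\n"]
for original_aten, nodes in original_aten_dict.items():
    metadata.append(f"# {original_aten} => {', '.join(n.name for n in nodes)}")
print("\n".join(metadata))
# prints "# Original ATen: aten.cos" followed by "# aten.cos => cos, cos_1"
```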
@@ -599,13 +626,16 @@ def get_num_bytes(*args):
 def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix=""):
-    import colorama
-
     info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
-    if ms > 0.012 and gb_per_s < 650:
-        return colorama.Fore.RED + info_str + colorama.Fore.RESET
-    else:
-        return info_str
+    try:
+        import colorama
+
+        if ms > 0.012 and gb_per_s < 650:
+            info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+    except ImportError:
+        log.warning("Colorama is not installed. Install it if you want colored output")
+    return info_str
 
 
 def get_benchmark_name():

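A quick way to exercise both paths of the hardened helper (assuming it lives in `torch._inductor.utils`; the 0.012 ms / 650 GB/s thresholds come from the code above):

```python
from torch._inductor.utils import create_bandwidth_info_str

# Slow, low-bandwidth call: printed in ANSI red when colorama is available;
# without colorama the plain string is returned and a warning is logged
# instead of raising ImportError.
print(create_bandwidth_info_str(0.021, 0.004, 190.5, prefix="triton_fused_cos_0 "))
# Fast, high-bandwidth call: always plain.
print(create_bandwidth_info_str(0.105, 0.512, 4876.2, prefix="triton_fused_mm_1 "))
```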

@@ -322,7 +322,7 @@ class FlopCounterMode(TorchDispatchMode):
         import tabulate
 
         tabulate.PRESERVE_WHITESPACE = True
-        header = ["Module", "FLOPS", "% Total"]
+        header = ["Module", "FLOP", "% Total"]
         values = []
         global_flops = sum(self.flop_counts['Global'].values())
         global_suffix = get_suffix_str(global_flops)
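The last hunk is a terminology fix: the column holds a raw operation count (FLOP), not a rate (FLOPS). Minimal usage, assuming the default `FlopCounterMode` constructor:

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

with FlopCounterMode(display=True):
    torch.mm(torch.randn(64, 128), torch.randn(128, 32))
# The printed tabulate table's middle column is now headed "FLOP";
# for this mm it reports 2 * 64 * 32 * 128 = 524288.
```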