Add extra metadata (as comments) to Inductor generated code (#96581)
New output: screenshot at https://user-images.githubusercontent.com/6355099/224794006-a993a2a8-d6ff-49da-8891-7b2373030a3d.png

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96581
Approved by: https://github.com/ngimel, https://github.com/shunting314, https://github.com/voznesenskym
This commit is contained in:
parent f56cb41c2e
commit 2a08a62777
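For illustration, the metadata comments this change emits above each compiled kernel in the generated wrapper look roughly like the sketch below. The kernel path, kernel name, and node names are invented for this example, not taken from the screenshot:

# kernel path: /tmp/torchinductor/xy/cxy123abc.py
# Original ATen: aten.cos, aten.mm

# aten.cos => cos, cos_1
# aten.mm => mm
triton_fused_cos_cos_1_mm_0 = async_compile.triton('''
...
''')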
@@ -7505,6 +7505,23 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                 func_and_kernel_torch
             )
 
+        @patch.object(config, "profile_bandwidth", True)
+        def test_bandwidth_profiler(self):
+            @torch._dynamo.optimize("inductor")
+            def fn(x):
+                x = x.cos()
+                x = x.cos()
+                x = torch.mm(x, x)
+                x = x.sin()
+                x = x.relu()
+                return x
+
+            inp = torch.randn(4, 4, device="cuda")
+            code = run_and_get_triton_code(fn, inp)
+            fn(inp)
+            self.assertTrue("start_graph" in code)
+            self.assertTrue("end_graph" in code)
+
         def test_split_op_with_sym(self):
             def fn(x: torch.Tensor) -> torch.Tensor:
                 # split(tensor, sympy.Integer), split(tensor, sympy.Expr)
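The new test captures the generated Triton source and checks for the profiler markers. As a standalone sketch (assuming a CUDA build and that run_and_get_triton_code is importable from torch._inductor.utils, as this test file does), the same check is:

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

config.profile_bandwidth = True  # the same flag the test patches

@torch._dynamo.optimize("inductor")
def fn(x):
    # mirrors the test: pointwise ops fused around a mm
    return torch.mm(x.cos(), x.cos()).sin().relu()

inp = torch.randn(4, 4, device="cuda")
code = run_and_get_triton_code(fn, inp)
assert "start_graph" in code and "end_graph" in code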
@@ -19,6 +19,7 @@ from ..ir import ReductionHint
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..utils import (
     get_fused_kernel_name,
+    get_kernel_metadata,
     instance_descriptor,
     next_power_of_2,
     sympy_product,
@@ -1673,7 +1674,11 @@ class TritonScheduling:
             compile_wrapper.splice(src_code, strip=True)
             compile_wrapper.writeline("''')")
 
-            wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), kernel_path)
+            metadata_comment = f"# kernel path: {kernel_path}"
+            metadata_comment += "\n" + get_kernel_metadata(node_schedule)
+            wrapper.define_kernel(
+                kernel_name, compile_wrapper.getvalue(), metadata_comment
+            )
         return kernel_name
 
     def codegen_template(self, template_node, epilogue_nodes):
@@ -628,9 +628,9 @@ class WrapperCodeGen(CodeGen):
         with output.indent():
             output.writeline("benchmark_compiled_module()")
 
-    def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
-        kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
-        self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")
+    def define_kernel(self, name: str, kernel: str, metadata: str = None):
+        metadata_comment = f"{metadata}\n" if metadata else ""
+        self.header.splice(f"\n\n{metadata_comment}{name} = {kernel}")
 
     def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
         return
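The wrapper-side effect of the new signature is that an arbitrary metadata string, rather than just a kernel path, is spliced above the kernel assignment. A minimal self-contained sketch of that behavior (the free function and names here are invented, not Inductor's actual classes):

def define_kernel(header, name, kernel, metadata=None):
    # mirrors WrapperCodeGen.define_kernel: prepend the metadata comment if present
    metadata_comment = f"{metadata}\n" if metadata else ""
    header.append(f"\n\n{metadata_comment}{name} = {kernel}")

header = []
define_kernel(
    header,
    "triton_fused_cos_0",
    "async_compile.triton('''...''')",
    metadata="# kernel path: /tmp/kernel.py\n# Original ATen: aten.cos",
)
print("".join(header))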
@@ -212,9 +212,10 @@ class triton:
 
     # should we put op names in kernel names
     # False: No special names (just triton__1, triton__2, etc.)
-    # "torch": Maps to the fx node in the Dynamo graph (module name, method name, etc.)
-    # "aten": Maps to the highest-level aten op (i.e. pre-decompositions)
-    descriptive_names = "aten"
+    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
+    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
+    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
+    descriptive_names = "original_aten"
 
     # use alternate codegen for smaller reductions
     persistent_reductions = True
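Illustratively, for a graph that fuses a cos with a mm, the settings would name the kernel along these lines (names invented; the exact names depend on the graph):

# descriptive_names = False            -> triton__0
# descriptive_names = "torch"          -> triton_fused_cos_mm_0   (post-dynamo fx ops)
# descriptive_names = "original_aten"  -> triton_fused_cos_mm_0   (pre-decomposition aten ops)
# descriptive_names = "inductor_node"  -> triton_fused_cos_cos_1_mm_0   (FX node names)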
@@ -233,8 +233,8 @@ def start_graph():
 def end_graph():
     if len(collected_calls) == 0:
         return
-    overall_time = sum(call[1] for call in collected_calls)
-    overall_gb = sum(call[2] for call in collected_calls)
+    overall_time = sum(call[0] for call in collected_calls)
+    overall_gb = sum(call[1] for call in collected_calls)
     cur_file = inspect.stack()[1].filename
     print(f"SUMMARY ({cur_file})")
     print(
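The index shift suggests each record in collected_calls now starts with the timing rather than a leading name field. A toy check of the summary arithmetic under that assumed (ms, num_gb, ...) layout:

collected_calls = [(0.10, 0.002), (0.25, 0.008)]  # assumed (ms, num_gb) records
overall_time = sum(call[0] for call in collected_calls)  # 0.35 ms
overall_gb = sum(call[1] for call in collected_calls)    # 0.010 GB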
@@ -238,13 +238,14 @@ def get_fused_kernel_name(node_schedule):
         operator.or_,
         [node.node.origins for node in node_schedule if hasattr(node, "node")],
     )
-    if config.triton.descriptive_names == "aten":
+    if config.triton.descriptive_names == "original_aten":
         # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
         sources = [
             origin.meta["original_aten"]._overloadpacket.__name__
             for origin in all_origins
             if origin.op == "call_function" and "original_aten" in origin.meta
         ]
+        sources = sorted(set(sources))
     elif config.triton.descriptive_names == "torch":
         # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
         sources = []
@@ -254,13 +255,39 @@ def get_fused_kernel_name(node_schedule):
                     sources.append(origin.meta["source_fn"])
                 else:
                     sources.append(origin.meta["source_fn"].__name__)
+        sources = sorted(set(sources))
+    elif config.triton.descriptive_names == "inductor_node":
+        sources = [
+            origin.name for origin in all_origins if origin.op == "call_function"
+        ]
     else:
         raise NotImplementedError
-    sources = set(sources)
-    sources = sorted(sources)[: config.kernel_name_max_ops]
+    sources = sources[: config.kernel_name_max_ops]
     return "_".join(["fused"] + sources)
 
 
+def get_kernel_metadata(node_schedule):
+    all_origins = functools.reduce(
+        operator.or_,
+        [node.node.origins for node in node_schedule if hasattr(node, "node")],
+    )
+    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]
+    original_aten_dict = collections.defaultdict(list)
+    for node in inductor_nodes:
+        if "original_aten" in node.meta:
+            original_aten_dict[str(node.meta["original_aten"]._overloadpacket)].append(
+                node
+            )
+    metadata = [
+        f"# Original ATen: {', '.join(original_aten_dict.keys())}\n",
+    ]
+    for original_aten, nodes in original_aten_dict.items():
+        metadata.append(
+            f"# {original_aten} => {', '.join([node.name for node in nodes])}"
+        )
+    return "\n".join(metadata)
+
+
 def gather_origins(args, kwargs):
     import itertools
 
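Given a node schedule whose FX origins include, say, two nodes decomposed from aten.cos and one from aten.mm, get_kernel_metadata returns a comment block shaped like this (node names invented):

# Original ATen: aten.cos, aten.mm

# aten.cos => cos, cos_1
# aten.mm => mm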
@@ -599,13 +626,16 @@ def get_num_bytes(*args):
 
 
 def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix=""):
-    import colorama
-
     info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
-    if ms > 0.012 and gb_per_s < 650:
-        return colorama.Fore.RED + info_str + colorama.Fore.RESET
-    else:
-        return info_str
+    try:
+        import colorama
+
+        if ms > 0.012 and gb_per_s < 650:
+            info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+    except ImportError:
+        log.warning("Colorama is not installed. Install it if you want colored output")
+
+    return info_str
 
 
 def get_benchmark_name():
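A quick usage sketch of the hardened helper (values invented): 0.020 ms at 200 GB/s trips the slow-kernel condition, so the string comes back wrapped in red when colorama is importable, and uncolored with a single warning otherwise:

s = create_bandwidth_info_str(0.020, 0.004, 200.0, prefix="triton_fused_cos_0 ")
print(s)  # roughly "triton_fused_cos_0 0.020ms  0.004 GB   200.00GB/s", red if colorama is installed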
@@ -322,7 +322,7 @@ class FlopCounterMode(TorchDispatchMode):
 
         import tabulate
         tabulate.PRESERVE_WHITESPACE = True
-        header = ["Module", "FLOPS", "% Total"]
+        header = ["Module", "FLOP", "% Total"]
         values = []
         global_flops = sum(self.flop_counts['Global'].values())
         global_suffix = get_suffix_str(global_flops)
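The header fix is purely a label change: the column reports a FLOP count, not a FLOP/s rate. A toy reproduction of the same tabulate pattern (values invented):

import tabulate

tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOP", "% Total"]
values = [["Global", "1200000000", "100.00%"]]
print(tabulate.tabulate(values, headers=header))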