[inductor] Add TLParse artifact for logging runtime of collective and compute ops (#159730)

Summary:

- debug.py: Added a log_runtime_estimates() function that dumps per-op runtime estimates as a structured tlparse artifact in JSON format (see the example payload sketch below)
- test_structured_trace.py: Added test coverage exercising both compute and collective ops
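
For reference, the payload logged under the "inductor_tlparse_runtime" artifact is a JSON object of roughly the following shape (a sketch based on log_runtime_estimates() in debug.py below; the op names and nanosecond values are illustrative, not taken from a real run):

# Illustrative shape of the "inductor_tlparse_runtime" payload.
# Keys match log_runtime_estimates(); the concrete entries are made up.
{
    "ops": [
        {"name": "extern_kernels.mm", "type": "compute", "estimated_runtime_ns": 15000.0},
        {
            "name": "torch.ops._c10d_functional.all_reduce.default",
            "type": "collective",
            "estimated_runtime_ns": 85000.0,
        },
    ]
}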

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159730
Approved by: https://github.com/yushangdi
ghstack dependencies: #159190
Authored by Sandeep Narendranath Karjala on 2025-08-05 11:30:55 -07:00; committed by PyTorch MergeBot
parent 64cc6f06b1
commit 8034b2a732
4 changed files with 180 additions and 0 deletions

@@ -1208,6 +1208,151 @@ def forward(self, x_1: "f32[2][1]cpu"):
        finally:
            dist.destroy_process_group()

    @contextmanager
    def _setup_runtime_estimates_capture(self):
        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
        payload_buffer = io.StringIO()
        payload_handler = logging.StreamHandler(payload_buffer)
        payload_handler.setLevel(logging.DEBUG)
        payload_handler.setFormatter(StructuredTracePayloadFormatter())
        payload_handler.addFilter(
            StructuredTraceTestingFilter("inductor_tlparse_runtime")
        )
        trace_log.addHandler(payload_handler)
        try:
            yield payload_buffer
        finally:
            trace_log.removeHandler(payload_handler)

    @requires_tlparse
    @requires_distributed()
    @requires_cuda
    @torch._inductor.config.patch("fx_graph_cache", False)
    @torch._inductor.config.patch("log_tlparse", True)
    def test_runtime_estimates_simple(self):
        """Test runtime estimates logging with simple compute and collective ops."""
        import torch.distributed as dist

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        class SimpleModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                h = self.linear(x)
                h = torch.relu(h)
                h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
                h = torch.ops._c10d_functional.wait_tensor.default(h)
                return h

        try:
            with self._setup_runtime_estimates_capture() as payload_buffer:
                torch._dynamo.reset()
                mod = SimpleModule().cuda()
                compiled = torch.compile(mod, backend="inductor")
                compiled(torch.randn(4, 4, device="cuda"))

                # Verify runtime estimates artifact was logged
                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())

                payload_content = payload_buffer.getvalue().strip()
                if payload_content:
                    data = json.loads(payload_content)
                    self.assertIn("ops", data)
                    ops = data["ops"]

                    # Verify runtime estimates
                    compute_ops = [op for op in ops if op["type"] == "compute"]
                    collective_ops = [op for op in ops if op["type"] == "collective"]
                    self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0)

                    # All ops should have runtime > 0 except wait_tensor can be 0
                    for op in ops:
                        if "wait_tensor" not in op["name"]:
                            self.assertGreater(
                                op["estimated_runtime_ns"],
                                0,
                                f"Op {op['name']} should have runtime > 0",
                            )

            self.assertParses()
        finally:
            dist.destroy_process_group()

    @requires_tlparse
    @requires_distributed()
    @requires_cuda
    @torch._inductor.config.patch("fx_graph_cache", False)
    @torch._inductor.config.patch("log_tlparse", True)
    def test_runtime_estimates_mixed(self):
        """Test runtime estimates logging with mixed compute and collective sequence."""
        import torch.distributed as dist

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        class MixedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.norm = torch.nn.LayerNorm(4)

            def forward(self, x):
                h = self.norm(x)
                h = torch.nn.functional.gelu(h)
                h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
                h = torch.ops._c10d_functional.wait_tensor.default(h)
                h = h * 0.5
                gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
                    h, 2, "0"
                )
                gathered = torch.ops._c10d_functional.wait_tensor.default(gathered)
                return gathered.sum(dim=0)

        try:
            with self._setup_runtime_estimates_capture() as payload_buffer:
                torch._dynamo.reset()
                mod = MixedModule().cuda()
                compiled = torch.compile(mod, backend="inductor")
                compiled(torch.randn(4, 4, device="cuda"))

                # Verify runtime estimates artifact was logged
                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())

                payload_content = payload_buffer.getvalue().strip()
                if payload_content:
                    data = json.loads(payload_content)
                    self.assertIn("ops", data)
                    ops = data["ops"]

                    # Should have both compute and collective ops
                    op_types = {op["type"] for op in ops}
                    self.assertIn("compute", op_types)
                    self.assertIn("collective", op_types)

                    # All ops should have runtime > 0 except wait_tensor can be 0
                    for op in ops:
                        if "wait_tensor" not in op["name"]:
                            self.assertGreater(
                                op["estimated_runtime_ns"],
                                0,
                                f"Op {op['name']} should have runtime > 0",
                            )

            self.assertParses()
        finally:
            dist.destroy_process_group()

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

@@ -1509,6 +1509,7 @@ class _InProcessFxCompile(FxCompile):
                compiled_module, "runner", None
            )

            node_runtimes = None
            if inductor_metrics_log.isEnabledFor(logging.INFO):
                num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
                metrics.num_bytes_accessed += num_bytes
@@ -1523,6 +1524,11 @@
                },
            )

            # Collect and dump op runtimes for TLParse
            if config.log_tlparse:
                _, _, node_runtimes = graph.count_bytes()
                torch._inductor.debug.log_runtime_estimates(node_runtimes)

            # Collect and dump collective-op schedule for external diagnostics
            torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)

@@ -741,6 +741,12 @@ worker_suppress_logging: bool = Config(
    default=True,
)

# Log per-operation runtime estimates for TLParse analysis.
log_tlparse: bool = Config(
    env_name_force="LOG_TLPARSE",
    default=False,
)

# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
# on by DDP and should not be set by the users.
_fuse_ddp_communication = False
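
As the test decorators above show, the artifact can be enabled either through torch._inductor.config.log_tlparse or the LOG_TLPARSE environment variable. A minimal usage sketch (the model and shapes are arbitrary; the tests above additionally use CUDA and a fake process group, but any torch.compile run exercises the compute path):

# Sketch: enable the per-op runtime artifact for a torch.compile run.
# Equivalent to running with LOG_TLPARSE=1; collect the structured trace as
# usual (e.g. TORCH_TRACE=/tmp/trace) and render it with tlparse.
import torch
import torch._inductor.config as inductor_config

inductor_config.log_tlparse = True

@torch.compile(backend="inductor")
def f(x):
    return torch.relu(x @ x)

f(torch.randn(64, 64))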

@@ -22,6 +22,7 @@ from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_co
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch._inductor import utils
from torch._logging import getArtifactLogger
from torch._logging._internal import trace_structured
from torch.fx.graph_module import GraphModule
@@ -721,6 +722,28 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
    _dump_collective_schedule(schedule)


def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None:
    """Log per-operation runtime estimates for TLParse."""
    ops = [
        {
            "name": getattr(s.node, "python_kernel_name", s.get_name()),
            "type": "collective" if utils.is_collective(s.node) else "compute",
            "estimated_runtime_ns": runtime_ns,
        }
        for s, runtime_ns in node_runtimes
    ]

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "inductor_tlparse_runtime",
            "encoding": "json",
        },
        payload_fn=lambda: {"ops": ops},
    )


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata