[inductor] Add logging for distributed collective ops for multi‑rank diagnostics (#159190)

This change introduces structured logging of the collective communication schedule, enabling downstream tools (e.g. TLParse) to ingest and analyze per‑rank collective‐order information for multi‑rank jobs.

- Iterates over scheduler.nodes, filters for _CollectiveKernel nodes
- Extracts each op’s python_kernel_name
- Emits a structured JSON payload under the inductor_collective_schedule artifact name
- Dumps the full schedule list to collective_schedule.json via the PyTorch trace‑structured artifact
- Adds unit tests for collective schedule tracing: test_collective_schedule_empty() and test_collective_schedule_real() verify that structured trace logging works both for an empty collective schedule and for real collective operations (all_reduce and wait_tensor from the _c10d_functional ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159190
Approved by: https://github.com/yushangdi, https://github.com/xmfan
This commit is contained in:
Sandeep Narendranath Karjala 2025-07-31 09:03:26 -07:00 committed by PyTorch MergeBot
parent 327e2ca580
commit bb62e1f769
3 changed files with 116 additions and 0 deletions

View File

@ -10,6 +10,7 @@ import shutil
import subprocess
import tempfile
import unittest.mock
from contextlib import contextmanager
import torch
import torch._dynamo.test_case
@ -21,6 +22,7 @@ from torch._inductor.test_case import TestCase
from torch._logging._internal import TorchLogsFormatter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import find_free_port
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_CUDA
@ -256,6 +258,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -288,6 +291,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -327,6 +331,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -347,6 +352,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
@ -377,6 +383,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -434,6 +441,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
@ -443,6 +451,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
@ -667,6 +676,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
@ -686,6 +696,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
@ -725,6 +736,7 @@ class StructuredTraceTest(TestCase):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
@ -884,6 +896,7 @@ def forward(self, x, y):
{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "inductor_collective_schedule", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@ -1120,6 +1133,78 @@ def forward(self, x_1: "f32[2][1]cpu"):
f(torch.randn(i + 2 // 3, 5))
step.next_step()
@contextmanager
def _setup_collective_schedule_capture(self):
    """Capture 'inductor_collective_schedule' structured-trace payloads.

    Attaches a temporary DEBUG-level ``logging.StreamHandler`` to
    ``trace_log``. The handler uses ``StructuredTracePayloadFormatter`` and a
    ``StructuredTraceTestingFilter`` constructed with the artifact name
    ``"inductor_collective_schedule"``.

    Yields:
        The ``io.StringIO`` buffer that the handler writes into.

    The handler is removed from ``trace_log`` when the ``with`` block exits,
    including when the block raises.
    """
    captured = io.StringIO()

    handler = logging.StreamHandler(captured)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(StructuredTracePayloadFormatter())
    artifact_filter = StructuredTraceTestingFilter("inductor_collective_schedule")
    handler.addFilter(artifact_filter)

    trace_log.addHandler(handler)
    try:
        yield captured
    finally:
        # Always detach so later tests do not see this handler.
        trace_log.removeHandler(handler)
@requires_tlparse
def test_collective_schedule_empty(self):
    """Check that an empty node list still logs the artifact with a ``[]`` payload."""
    from torch._inductor.debug import log_collective_schedule

    with self._setup_collective_schedule_capture() as captured:
        log_collective_schedule([])

        # The artifact record must show up in the main structured-trace buffer.
        trace_output = self.buffer.getvalue()
        self.assertIn('"inductor_collective_schedule"', trace_output)

        # The captured payload must decode to an empty JSON list.
        payload = json.loads(captured.getvalue())
        self.assertEqual(payload, [])

        self.assertParses()
@requires_tlparse
@requires_distributed()
@torch._inductor.config.patch("fx_graph_cache", False)
def test_collective_schedule_real(self):
    """Check the logged schedule for a compiled module using _c10d_functional ops.

    Initializes a "fake"-backend process group (rank 0, world size 2) backed
    by a ``FakeStore``, compiles a module that calls ``all_reduce`` followed
    by ``wait_tensor``, and asserts on the captured schedule payload. The
    process group is destroyed afterwards regardless of the test outcome.
    """
    import torch.distributed as dist

    class CollectiveModule(torch.nn.Module):
        def forward(self, x):
            # Use _c10d_functional ops that actually trigger collective kernels
            reduced = torch.ops._c10d_functional.all_reduce.default(x, "sum", "0")
            waited = torch.ops._c10d_functional.wait_tensor.default(reduced)
            return waited * 2

    dist.init_process_group(
        backend="fake", rank=0, world_size=2, store=FakeStore()
    )
    try:
        with self._setup_collective_schedule_capture() as captured:
            torch._dynamo.reset()

            compiled = torch.compile(CollectiveModule(), backend="inductor")
            compiled(torch.randn(4, 4))

            # Verify collective schedule artifact was logged
            self.assertIn('"inductor_collective_schedule"', self.buffer.getvalue())

            schedule = json.loads(captured.getvalue().strip())
            self.assertIsInstance(schedule, list)

            # Verify expected collective operations are present
            self.assertExpectedInline(
                str(schedule),
                """\
['torch.ops._c10d_functional.all_reduce_.default', 'torch.ops._c10d_functional.wait_tensor.default']\
""",
            )
            self.assertParses()
    finally:
        dist.destroy_process_group()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1523,6 +1523,9 @@ class _InProcessFxCompile(FxCompile):
},
)
# Collect and dump collective-op schedule for external diagnostics
torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)
if (
cudagraphs
and config.triton.cudagraph_skip_dynamic_graphs

View File

@ -23,6 +23,7 @@ from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch._logging import getArtifactLogger
from torch._logging._internal import trace_structured
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
@ -693,6 +694,33 @@ def log_ir_post_fusion(nodes: SchedulerNodeList) -> None:
V.debug.ir_post_fusion(nodes)
def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None:
    """Emit ``schedule`` as the ``inductor_collective_schedule`` trace artifact.

    Args:
        schedule: Entries to log. The list is handed to ``trace_structured``
            unchanged as the artifact payload, with ``"json"`` encoding
            declared in the metadata.

    Any exception raised while emitting the artifact is logged at debug level
    (with traceback) and is not re-raised.
    """

    def _metadata() -> dict[str, str]:
        return {
            "name": "inductor_collective_schedule",
            "encoding": "json",
        }

    def _payload() -> list[Union[str, None]]:
        return schedule

    try:
        trace_structured("artifact", metadata_fn=_metadata, payload_fn=_payload)
    except Exception:
        # Best-effort diagnostics: record the failure and carry on.
        log.debug(
            "Failed to log inductor_collective_schedule via structured logging",
            exc_info=True,
        )
def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
    """Log the kernel names of the collective ops found in ``nodes``.

    Walks ``nodes`` in order. For every entry whose ``node`` attribute is an
    ``ir._CollectiveKernel``, the op's ``python_kernel_name`` is recorded
    (``None`` if the op has no such attribute). The resulting list is passed
    to ``_dump_collective_schedule`` even when it is empty.

    Args:
        nodes: Scheduler nodes to inspect. Entries without a ``node``
            attribute are skipped.
    """
    schedule: list[Union[str, None]] = []
    for sched_node in nodes:
        op = getattr(sched_node, "node", None)
        if not isinstance(op, ir._CollectiveKernel):
            continue
        schedule.append(getattr(op, "python_kernel_name", None))

    _dump_collective_schedule(schedule)
@dataclasses.dataclass
class TensorMetadataHolder:
tensor_metadata: TensorMetadata