[inductor] Add TLParse artifact for logging runtime of collective and compute ops (#159730)
Summary:
- debug.py: Added a log_runtime_estimates() function to dump runtime estimation data as a structured tlparse artifact in JSON format.
- test_structured_trace.py: Added test coverage for both compute and collective ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159730
Approved by: https://github.com/yushangdi
ghstack dependencies: #159190
This commit is contained in:
parent 64cc6f06b1
commit 8034b2a732
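To see the new artifact end to end, here is a minimal usage sketch that enables the flag and compiles a toy module. The module and shapes are illustrative; log_tlparse, the LOG_TLPARSE env var, and the artifact name "inductor_tlparse_runtime" all come from the hunks below.

    # Minimal sketch (assumes a CUDA build of PyTorch):
    import torch
    import torch._inductor.config as inductor_config

    inductor_config.log_tlparse = True  # or set LOG_TLPARSE=1 in the environment

    mod = torch.nn.Linear(4, 4).cuda()
    compiled = torch.compile(mod, backend="inductor")
    compiled(torch.randn(4, 4, device="cuda"))

    # With TORCH_TRACE=<dir> set, the structured trace now contains an
    # "inductor_tlparse_runtime" artifact whose JSON payload lists per-op
    # estimated runtimes in nanoseconds (see log_runtime_estimates below).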
test/dynamo/test_structured_trace.py

@@ -1208,6 +1208,151 @@ def forward(self, x_1: "f32[2][1]cpu"):
        finally:
            dist.destroy_process_group()

    @contextmanager
    def _setup_runtime_estimates_capture(self):
        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
        payload_buffer = io.StringIO()
        payload_handler = logging.StreamHandler(payload_buffer)
        payload_handler.setLevel(logging.DEBUG)
        payload_handler.setFormatter(StructuredTracePayloadFormatter())
        payload_handler.addFilter(
            StructuredTraceTestingFilter("inductor_tlparse_runtime")
        )
        trace_log.addHandler(payload_handler)
        try:
            yield payload_buffer
        finally:
            trace_log.removeHandler(payload_handler)

    @requires_tlparse
    @requires_distributed()
    @requires_cuda
    @torch._inductor.config.patch("fx_graph_cache", False)
    @torch._inductor.config.patch("log_tlparse", True)
    def test_runtime_estimates_simple(self):
        """Test runtime estimates logging with simple compute and collective ops."""
        import torch.distributed as dist

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        class SimpleModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                h = self.linear(x)
                h = torch.relu(h)

                h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
                h = torch.ops._c10d_functional.wait_tensor.default(h)
                return h

        try:
            with self._setup_runtime_estimates_capture() as payload_buffer:
                torch._dynamo.reset()

                mod = SimpleModule().cuda()
                compiled = torch.compile(mod, backend="inductor")
                compiled(torch.randn(4, 4, device="cuda"))

                # Verify runtime estimates artifact was logged
                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())

                payload_content = payload_buffer.getvalue().strip()
                if payload_content:
                    data = json.loads(payload_content)
                    self.assertIn("ops", data)
                    ops = data["ops"]

                    # Verify runtime estimates
                    compute_ops = [op for op in ops if op["type"] == "compute"]
                    collective_ops = [op for op in ops if op["type"] == "collective"]

                    self.assertTrue(len(compute_ops) > 0 or len(collective_ops) > 0)

                    # All ops should have runtime > 0, except wait_tensor, which can be 0
                    for op in ops:
                        if "wait_tensor" not in op["name"]:
                            self.assertGreater(
                                op["estimated_runtime_ns"],
                                0,
                                f"Op {op['name']} should have runtime > 0",
                            )

            self.assertParses()
        finally:
            dist.destroy_process_group()

    @requires_tlparse
    @requires_distributed()
    @requires_cuda
    @torch._inductor.config.patch("fx_graph_cache", False)
    @torch._inductor.config.patch("log_tlparse", True)
    def test_runtime_estimates_mixed(self):
        """Test runtime estimates logging with a mixed compute and collective sequence."""
        import torch.distributed as dist

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        class MixedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.norm = torch.nn.LayerNorm(4)

            def forward(self, x):
                h = self.norm(x)
                h = torch.nn.functional.gelu(h)

                h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
                h = torch.ops._c10d_functional.wait_tensor.default(h)

                h = h * 0.5

                gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
                    h, 2, "0"
                )
                gathered = torch.ops._c10d_functional.wait_tensor.default(gathered)

                return gathered.sum(dim=0)

        try:
            with self._setup_runtime_estimates_capture() as payload_buffer:
                torch._dynamo.reset()

                mod = MixedModule().cuda()
                compiled = torch.compile(mod, backend="inductor")
                compiled(torch.randn(4, 4, device="cuda"))

                # Verify runtime estimates artifact was logged
                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())

                payload_content = payload_buffer.getvalue().strip()
                if payload_content:
                    data = json.loads(payload_content)
                    self.assertIn("ops", data)
                    ops = data["ops"]

                    # Should have both compute and collective ops
                    op_types = {op["type"] for op in ops}
                    self.assertIn("compute", op_types)
                    self.assertIn("collective", op_types)

                    # All ops should have runtime > 0, except wait_tensor, which can be 0
                    for op in ops:
                        if "wait_tensor" not in op["name"]:
                            self.assertGreater(
                                op["estimated_runtime_ns"],
                                0,
                                f"Op {op['name']} should have runtime > 0",
                            )

            self.assertParses()
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
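For reference, the payload captured by test_runtime_estimates_simple is shaped roughly like this. The op names and runtimes are hypothetical; the fields match log_runtime_estimates in the debug.py hunk further down. Whether wait_tensor is classified as compute or collective depends on utils.is_collective, which is why the tests exempt it by name rather than by type.

    # Hypothetical captured payload (names, runtimes, and the wait_tensor
    # classification are illustrative):
    {
        "ops": [
            {"name": "triton_poi_fused_relu_0", "type": "compute", "estimated_runtime_ns": 1520.0},
            {"name": "torch.ops._c10d_functional.all_reduce_.default", "type": "collective", "estimated_runtime_ns": 84000.0},
            {"name": "torch.ops._c10d_functional.wait_tensor.default", "type": "compute", "estimated_runtime_ns": 0.0}
        ]
    }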
torch/_inductor/compile_fx.py

@@ -1509,6 +1509,7 @@ class _InProcessFxCompile(FxCompile):
                    compiled_module, "runner", None
                )

                node_runtimes = None
                if inductor_metrics_log.isEnabledFor(logging.INFO):
                    num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
                    metrics.num_bytes_accessed += num_bytes

@@ -1523,6 +1524,11 @@ class _InProcessFxCompile(FxCompile):
                        },
                    )

                # Collect and dump op runtimes for TLParse
                if config.log_tlparse:
                    _, _, node_runtimes = graph.count_bytes()
                    torch._inductor.debug.log_runtime_estimates(node_runtimes)

                # Collect and dump collective-op schedule for external diagnostics
                torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)
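Data-flow note (inferred from the hunks above and the debug.py hunk below, not text from the PR): graph.count_bytes() returns a (num_bytes, nodes_num_elem, node_runtimes) triple, whose third element pairs each scheduler node with its estimated runtime in nanoseconds, roughly Sequence[tuple[BaseSchedulerNode, float]] — exactly what log_runtime_estimates() consumes. As the hunks are shown here, count_bytes() runs a second time when both inductor_metrics_log and log_tlparse are enabled.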
torch/_inductor/config.py

@@ -741,6 +741,12 @@ worker_suppress_logging: bool = Config(
    default=True,
)

# Log per-operation runtime estimates for TLParse analysis.
log_tlparse: bool = Config(
    env_name_force="LOG_TLPARSE",
    default=False,
)

# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
# on by DDP and should not be set by the users.
_fuse_ddp_communication = False
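Note on the config entry: with env_name_force, the LOG_TLPARSE environment variable takes precedence over a value assigned in code (assuming the usual convention for forced env names in torch._inductor.config), so a run can be instrumented as LOG_TLPARSE=1 python train.py without any source change; train.py is an illustrative script name.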
torch/_inductor/debug.py

@@ -22,6 +22,7 @@ from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch._inductor import utils
from torch._logging import getArtifactLogger
from torch._logging._internal import trace_structured
from torch.fx.graph_module import GraphModule

@@ -721,6 +722,28 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
    _dump_collective_schedule(schedule)


def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None:
    """Log per-operation runtime estimates for TLParse."""

    ops = [
        {
            "name": getattr(s.node, "python_kernel_name", s.get_name()),
            "type": "collective" if utils.is_collective(s.node) else "compute",
            "estimated_runtime_ns": runtime_ns,
        }
        for s, runtime_ns in node_runtimes
    ]

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "inductor_tlparse_runtime",
            "encoding": "json",
        },
        payload_fn=lambda: {"ops": ops},
    )


@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
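A small consumer sketch for the artifact: given the JSON payload string extracted from a structured-trace log, total the estimated runtime per op type. The extraction step is assumed; the aggregation relies only on the fields produced by log_runtime_estimates above.

    import json
    from collections import defaultdict

    def summarize_runtime_payload(payload: str) -> dict[str, float]:
        """Sum estimated runtimes (ns) per op type from an
        'inductor_tlparse_runtime' payload."""
        totals: dict[str, float] = defaultdict(float)
        for op in json.loads(payload)["ops"]:
            totals[op["type"]] += op["estimated_runtime_ns"]
        return dict(totals)

    # Example with an inline payload shaped like the one emitted above:
    payload = (
        '{"ops": ['
        '{"name": "mm", "type": "compute", "estimated_runtime_ns": 5000.0},'
        '{"name": "all_reduce", "type": "collective", "estimated_runtime_ns": 12000.0}'
        ']}'
    )
    print(summarize_runtime_payload(payload))  # {'compute': 5000.0, 'collective': 12000.0}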