[inductor] TLParse tensor metadata logging + test (#160132)

Summary:
- Add TLParse artifact logging per op with output tensor shape, stride, and dtype for cross-rank aggregation.
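
A sketch of one emitted payload, reconstructed from the diff below; the kernel
name and numbers are illustrative, only the schema comes from
log_runtime_and_tensor_meta:

    # Hypothetical single-op "inductor_runtime_and_tensor_meta" payload.
    payload = {
        "ops": [
            {
                "name": "triton_poi_fused_relu_0",  # illustrative kernel name
                "type": "compute",  # "compute" or "collective"
                "estimated_runtime_ns": 1234.5,
                "outputs": [{"shape": [4, 4], "stride": [4, 1], "dtype": "float32"}],
            }
        ]
    }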

Testing:
- Add a test to verify the structure and contents of the tlparse artifact

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160132
Approved by: https://github.com/xmfan
ghstack dependencies: #160260
Authored by Sandeep Narendranath Karjala on 2025-08-15 14:35:43 -07:00; committed by PyTorch MergeBot.
parent 8fe4b3f848
commit 2603e40be5
3 changed files with 212 additions and 27 deletions

test/dynamo/test_structured_trace.py

@@ -25,10 +25,10 @@ from torch.testing._internal.common_utils import find_free_port
-from torch.testing._internal.triton_utils import requires_cuda_and_triton
+requires_cuda_and_triton = unittest.skipUnless(HAS_CUDA, "requires cuda")
 
 if torch.distributed.is_available():
     from torch.testing._internal.distributed.fake_pg import FakeStore
 
 HAS_TLPARSE = shutil.which("tlparse") is not None
 requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse")
 requires_distributed = functools.partial(
@@ -1198,13 +1198,13 @@ def forward(self, x_1: "f32[2][1]cpu"):
     @contextmanager
     def _setup_runtime_estimates_capture(self):
-        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
+        """Helper to turn on and capture the combined 'inductor_runtime_and_tensor_meta' structured trace."""
         payload_buffer = io.StringIO()
         payload_handler = logging.StreamHandler(payload_buffer)
         payload_handler.setLevel(logging.DEBUG)
         payload_handler.setFormatter(StructuredTracePayloadFormatter())
         payload_handler.addFilter(
-            StructuredTraceTestingFilter("inductor_tlparse_runtime")
+            StructuredTraceTestingFilter("inductor_runtime_and_tensor_meta")
         )
         trace_log.addHandler(payload_handler)
         try:
@@ -1245,8 +1245,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))
 
-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify runtime + tensor meta artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )
 
                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1310,8 +1312,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))
 
-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )
 
                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1333,6 +1337,145 @@ def forward(self, x_1: "f32[2][1]cpu"):
         finally:
             dist.destroy_process_group()
 
+    @requires_tlparse
+    @requires_distributed()
+    @requires_cuda_and_triton
+    @torch._inductor.config.patch("fx_graph_cache", False)
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_multiple_ops(self):
+        import torch.distributed as dist
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+
+        class Mixed(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                y = torch.relu(self.linear(x))
+                y = torch.ops._c10d_functional.all_reduce.default(y, "sum", "0")
+                y = torch.ops._c10d_functional.wait_tensor.default(y)
+                return y + 1
+
+        try:
+            with self._setup_runtime_estimates_capture() as payload_buffer:
+                torch._dynamo.reset()
+                mod = Mixed().cuda()
+                compiled = torch.compile(mod, backend="inductor")
+                compiled(torch.randn(4, 4, device="cuda"))
+
+                payload = payload_buffer.getvalue().strip()
+                if payload:
+                    data = json.loads(payload)
+                    types = sorted({op.get("type") for op in data.get("ops", [])})
+                    self.assertExpectedInline(
+                        str(types), """['collective', 'compute']"""
+                    )
+            self.assertParses()
+        finally:
+            dist.destroy_process_group()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging(self):
+        """Emit unified runtime+tensor-metadata artifact and assert a stable simplified JSON inline."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+                self.assertExpectedInline(
+                    {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []},
+                    """{'ops': [{'type': 'compute', 'outputs': [{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}""",
+                )
+        self.assertParses()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_dynamic_shapes(self):
+        """Same as test_tensor_metadata_logging, but with dynamic shapes enabled to cover to_size_hints."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", dynamic=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+                self.assertExpectedInline(
+                    {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []},
+                    (
+                        "{'ops': [{'type': 'compute', 'outputs': ["
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float32'}, "
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}"
+                    ),
+                )
+        self.assertParses()
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

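Outside the test harness, the artifact can be reproduced with a plain compile run.
A minimal sketch, assuming the log_tlparse config flag added in this stack; run with
TORCH_TRACE pointing at a trace directory, then inspect that directory with the
tlparse CLI:

    import torch

    # Enable the TLParse artifact and trigger an inductor compile.
    # With TORCH_TRACE=/tmp/trace set in the environment, the structured trace
    # (including "inductor_runtime_and_tensor_meta") lands in /tmp/trace.
    with torch._inductor.config.patch("log_tlparse", True):
        fn = torch.compile(lambda x: x.relu() + 1, backend="inductor")
        fn(torch.randn(4, 4))
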
torch/_inductor/compile_fx.py

@@ -1526,10 +1526,10 @@ class _InProcessFxCompile(FxCompile):
                 },
             )
 
-            # Collect and dump op runtimes for TLParse
+            # Collect and dump op runtimes and tensor metadata for TLParse
             if config.log_tlparse:
                 _, _, node_runtimes = graph.count_bytes()
-                torch._inductor.debug.log_runtime_estimates(node_runtimes)
+                torch._inductor.debug.log_runtime_and_tensor_meta(node_runtimes)
 
             # Collect and dump collective-op schedule for external diagnostics
             torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)

torch/_inductor/debug.py

@@ -737,26 +737,68 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
     _dump_collective_schedule(schedule)
 
 
-def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None:
-    """Log per-operation runtime estimates for TLParse."""
-    ops = [
-        {
-            "name": getattr(s.node, "python_kernel_name", s.get_name()),
-            "type": "collective" if utils.is_collective(s.node) else "compute",
-            "estimated_runtime_ns": runtime_ns,
-        }
-        for s, runtime_ns in node_runtimes
-    ]
+def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
+    """Log per-op runtime estimates and output tensor metadata for TLParse."""
     try:
-        trace_structured(
-            "artifact",
-            metadata_fn=lambda: {
-                "name": "inductor_tlparse_runtime",
-                "encoding": "json",
-            },
-            payload_fn=lambda: {"ops": ops},
-        )
+        to_size_hints = V.graph.sizevars.size_hints
+
+        def to_list(x: Optional[Sequence[Any]]) -> list[Any]:
+            return list(to_size_hints(x)) if x is not None else []
+
+        def dtype_to_str(dtype: Any) -> Optional[str]:
+            if dtype is None:
+                return None
+            s = str(dtype)
+            s = s.removeprefix("torch.")
+            return s
+
+        ops: list[dict[str, Any]] = []
+        for s, runtime_ns in node_runtimes:
+            name = getattr(s.node, "python_kernel_name", s.get_name())
+            op_type = "collective" if utils.is_collective(s.node) else "compute"
+
+            # Build outputs metadata if available
+            outputs: list[dict[str, Any]] = []
+            try:
+                for buf in s.get_outputs():
+                    irnode = buf.node
+                    shape = irnode.maybe_get_size()
+                    stride = (
+                        irnode.get_stride()
+                        if isinstance(irnode.layout, ir.Layout)
+                        else None
+                    )
+                    dtype = irnode.maybe_get_dtype()
+                    outputs.append(
+                        {
+                            "shape": to_list(shape),
+                            "stride": to_list(stride),
+                            "dtype": dtype_to_str(dtype),
+                        }
+                    )
+            except Exception:
+                pass
+
+            ops.append(
+                {
+                    "name": name,
+                    "type": op_type,
+                    "estimated_runtime_ns": runtime_ns,
+                    "outputs": outputs,
+                }
+            )
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_runtime_and_tensor_meta",
+                "encoding": "json",
+            },
+            payload_fn=lambda: {"ops": ops},
+        )
     except Exception:
-        log.debug("Failed to log inductor_tlparse_runtime", exc_info=True)
+        log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)
 
 
 @dataclasses.dataclass
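
The cross-rank aggregation mentioned in the summary is left to consumers; a sketch
of one, assuming only the payload schema above and one parsed JSON payload per rank
(the helper name is hypothetical):

    import json
    from collections import defaultdict

    def aggregate_runtimes(payloads):
        """Sum estimated runtimes per op type across per-rank payloads."""
        totals = defaultdict(float)
        for payload in payloads:
            for op in payload.get("ops", []):
                totals[op.get("type", "compute")] += op.get("estimated_runtime_ns", 0.0)
        return dict(totals)

    # Illustrative per-rank payload using the schema from this PR:
    rank0 = json.loads(
        '{"ops": [{"name": "k0", "type": "compute", "estimated_runtime_ns": 10.0,'
        ' "outputs": [{"shape": [2], "stride": [1], "dtype": "float16"}]}]}'
    )
    print(aggregate_runtimes([rank0]))  # {'compute': 10.0}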