diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py
index 5897c129b26..6e49f288f5f 100644
--- a/test/dynamo/test_structured_trace.py
+++ b/test/dynamo/test_structured_trace.py
@@ -28,7 +28,6 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton
 if torch.distributed.is_available():
     from torch.testing._internal.distributed.fake_pg import FakeStore
 
-
 HAS_TLPARSE = shutil.which("tlparse") is not None
 requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse")
 requires_distributed = functools.partial(
@@ -1198,13 +1197,13 @@ def forward(self, x_1: "f32[2][1]cpu"):
 
     @contextmanager
     def _setup_runtime_estimates_capture(self):
-        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
+        """Helper to turn on and capture the combined 'inductor_runtime_and_tensor_meta' structured trace."""
         payload_buffer = io.StringIO()
         payload_handler = logging.StreamHandler(payload_buffer)
         payload_handler.setLevel(logging.DEBUG)
         payload_handler.setFormatter(StructuredTracePayloadFormatter())
         payload_handler.addFilter(
-            StructuredTraceTestingFilter("inductor_tlparse_runtime")
+            StructuredTraceTestingFilter("inductor_runtime_and_tensor_meta")
         )
         trace_log.addHandler(payload_handler)
         try:
@@ -1245,8 +1244,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))
 
-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify runtime + tensor meta artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )
 
                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1310,8 +1311,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))
 
-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )
 
                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1333,6 +1336,145 @@ def forward(self, x_1: "f32[2][1]cpu"):
         finally:
             dist.destroy_process_group()
 
+    @requires_tlparse
+    @requires_distributed()
+    @requires_cuda_and_triton
+    @torch._inductor.config.patch("fx_graph_cache", False)
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_multiple_ops(self):
+        import torch.distributed as dist
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+
+        class Mixed(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                y = torch.relu(self.linear(x))
+                y = torch.ops._c10d_functional.all_reduce.default(y, "sum", "0")
+                y = torch.ops._c10d_functional.wait_tensor.default(y)
+                return y + 1
+
+        try:
+            with self._setup_runtime_estimates_capture() as payload_buffer:
+                torch._dynamo.reset()
+                mod = Mixed().cuda()
+                compiled = torch.compile(mod, backend="inductor")
+                compiled(torch.randn(4, 4, device="cuda"))
+                payload = payload_buffer.getvalue().strip()
+                if payload:
+                    data = json.loads(payload)
+                    types = sorted({op.get("type") for op in data.get("ops", [])})
+                    self.assertExpectedInline(
+                        str(types), """['collective', 'compute']"""
+                    )
+            self.assertParses()
+        finally:
+            dist.destroy_process_group()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging(self):
+        """Emit unified runtime+tensor-metadata artifact and assert a stable simplified JSON inline."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+
+                self.assertExpectedInline(
+                    {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []},
+                    """{'ops': [{'type': 'compute', 'outputs': [{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}""",
+                )
+
+        self.assertParses()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_dynamic_shapes(self):
+        """Same as test_tensor_metadata_logging, but with dynamic shapes enabled to cover to_size_hints."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", dynamic=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+
+                self.assertExpectedInline(
+                    {"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []},
+                    (
+                        "{'ops': [{'type': 'compute', 'outputs': ["
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float32'}, "
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}"
+                    ),
+                )
+
+        self.assertParses()
 
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 115e0efcc5d..3d614d6795b 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -1526,10 +1526,10 @@ class _InProcessFxCompile(FxCompile):
                 },
             )
 
-            # Collect and dump op runtimes for TLParse
+            # Collect and dump op runtimes and tensor metadata for TLParse
             if config.log_tlparse:
                 _, _, node_runtimes = graph.count_bytes()
-                torch._inductor.debug.log_runtime_estimates(node_runtimes)
+                torch._inductor.debug.log_runtime_and_tensor_meta(node_runtimes)
 
             # Collect and dump collective-op schedule for external diagnostics
             torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index 1fbb69563dc..a31d56bd385 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -737,26 +737,68 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
     _dump_collective_schedule(schedule)
 
 
-def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None:
-    """Log per-operation runtime estimates for TLParse."""
+def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
+    """Log per-op runtime estimates and output tensor metadata for TLParse."""
 
-    ops = [
-        {
-            "name": getattr(s.node, "python_kernel_name", s.get_name()),
-            "type": "collective" if utils.is_collective(s.node) else "compute",
-            "estimated_runtime_ns": runtime_ns,
-        }
-        for s, runtime_ns in node_runtimes
-    ]
+    try:
+        to_size_hints = V.graph.sizevars.size_hints
 
-    trace_structured(
-        "artifact",
-        metadata_fn=lambda: {
-            "name": "inductor_tlparse_runtime",
-            "encoding": "json",
-        },
-        payload_fn=lambda: {"ops": ops},
-    )
+        def to_list(x: Optional[Sequence[Any]]) -> list[Any]:
+            return list(to_size_hints(x)) if x is not None else []
+
+        def dtype_to_str(dtype: Any) -> Optional[str]:
+            if dtype is None:
+                return None
+            s = str(dtype)
+            s = s.removeprefix("torch.")
+            return s
+
+        ops: list[dict[str, Any]] = []
+        for s, runtime_ns in node_runtimes:
+            name = getattr(s.node, "python_kernel_name", s.get_name())
+            op_type = "collective" if utils.is_collective(s.node) else "compute"
+
+            # Build outputs metadata if available
+            outputs: list[dict[str, Any]] = []
+            try:
+                for buf in s.get_outputs():
+                    irnode = buf.node
+                    shape = irnode.maybe_get_size()
+                    stride = (
+                        irnode.get_stride()
+                        if isinstance(irnode.layout, ir.Layout)
+                        else None
+                    )
+                    dtype = irnode.maybe_get_dtype()
+                    outputs.append(
+                        {
+                            "shape": to_list(shape),
+                            "stride": to_list(stride),
+                            "dtype": dtype_to_str(dtype),
+                        }
+                    )
+            except Exception:
+                pass
+
+            ops.append(
+                {
+                    "name": name,
+                    "type": op_type,
+                    "estimated_runtime_ns": runtime_ns,
+                    "outputs": outputs,
+                }
+            )
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_runtime_and_tensor_meta",
+                "encoding": "json",
+            },
+            payload_fn=lambda: {"ops": ops},
+        )
+    except Exception:
+        log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)
 
 
 @dataclasses.dataclass
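
Reviewer note (not part of the patch): a minimal sketch of how a consumer of the new combined artifact might read the payload emitted by log_runtime_and_tensor_meta. The field names (name, type, estimated_runtime_ns, and per-output shape/stride/dtype) come from the diff above; the kernel name and runtime value in the sample payload are hypothetical.

import json

# Sample payload in the shape emitted by log_runtime_and_tensor_meta; the
# kernel name "example_kernel_0" and the runtime value are made up for
# illustration, but the schema and the "torch."-stripped dtype string match
# what the tests above assert on.
raw = (
    '{"ops": [{"name": "example_kernel_0", "type": "compute",'
    ' "estimated_runtime_ns": 1234.0,'
    ' "outputs": [{"shape": [2], "stride": [1], "dtype": "float16"}]}]}'
)

data = json.loads(raw)
for op in data.get("ops", []):
    print(op["name"], op["type"], op["estimated_runtime_ns"])
    for out in op.get("outputs", []):
        print("  shape:", out["shape"], "stride:", out["stride"], "dtype:", out["dtype"])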