[inductor] TLParse tensor metadata logging + test (#160132)
Summary:
- Add TLParse artifact logging per op with output tensor shape, stride, and dtype for cross-rank aggregation.

Testing:
- Add a test to verify the structure and contents of the tlparse artifact.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160132
Approved by: https://github.com/xmfan
ghstack dependencies: #160260
parent 8fe4b3f848
commit 2603e40be5
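For orientation, the artifact this change emits is a single JSON payload with one record per scheduled op. A minimal sketch of its shape, with made-up values (the keys mirror the ops built in log_runtime_and_tensor_meta in the diff below):

    # Illustrative shape of the "inductor_runtime_and_tensor_meta" payload.
    # Values are invented; only the key structure comes from the diff.
    payload = {
        "ops": [
            {
                "name": "extern_kernels.mm",   # python_kernel_name, or buffer name
                "type": "compute",             # "compute" or "collective"
                "estimated_runtime_ns": 12345.0,
                "outputs": [
                    {"shape": [4, 4], "stride": [4, 1], "dtype": "float32"},
                ],
            },
        ],
    }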
@@ -25,10 +25,10 @@ from torch.testing._internal.common_utils import find_free_port
-from torch.testing._internal.triton_utils import requires_cuda_and_triton
+requires_cuda_and_triton = unittest.skipUnless(HAS_CUDA, "requires cuda")

 if torch.distributed.is_available():
     from torch.testing._internal.distributed.fake_pg import FakeStore


 HAS_TLPARSE = shutil.which("tlparse") is not None
 requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse")
 requires_distributed = functools.partial(
@@ -1198,13 +1198,13 @@ def forward(self, x_1: "f32[2][1]cpu"):

     @contextmanager
     def _setup_runtime_estimates_capture(self):
-        """Helper to turn on and capture the 'inductor_tlparse_runtime' structured trace."""
+        """Helper to turn on and capture the combined 'inductor_runtime_and_tensor_meta' structured trace."""
         payload_buffer = io.StringIO()
         payload_handler = logging.StreamHandler(payload_buffer)
         payload_handler.setLevel(logging.DEBUG)
         payload_handler.setFormatter(StructuredTracePayloadFormatter())
         payload_handler.addFilter(
-            StructuredTraceTestingFilter("inductor_tlparse_runtime")
+            StructuredTraceTestingFilter("inductor_runtime_and_tensor_meta")
         )
         trace_log.addHandler(payload_handler)
         try:
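For readers outside the test file, the helper above is plain logging machinery: attach a StreamHandler to the structured-trace logger and filter on the artifact name. A generic, simplified sketch of that pattern (NameFilter is a stand-in invented here; the test's StructuredTracePayloadFormatter and StructuredTraceTestingFilter additionally understand the structured-trace record layout):

    import io
    import logging

    trace_log = logging.getLogger("torch.__trace")  # structured trace logger

    class NameFilter(logging.Filter):
        """Simplified stand-in: keep only records mentioning the artifact name."""
        def __init__(self, needle):
            super().__init__()
            self.needle = needle

        def filter(self, record):
            # Real structured-trace records carry metadata on the record;
            # fall back to the formatted message for the generic case.
            blob = str(getattr(record, "metadata", "")) + record.getMessage()
            return self.needle in blob

    buf = io.StringIO()
    handler = logging.StreamHandler(buf)
    handler.setLevel(logging.DEBUG)
    handler.addFilter(NameFilter("inductor_runtime_and_tensor_meta"))
    trace_log.addHandler(handler)
    try:
        pass  # run the compiled program here, then inspect buf.getvalue()
    finally:
        trace_log.removeHandler(handler)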
@@ -1245,8 +1245,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))

-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify runtime + tensor meta artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )

                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1310,8 +1312,10 @@ def forward(self, x_1: "f32[2][1]cpu"):
                 compiled = torch.compile(mod, backend="inductor")
                 compiled(torch.randn(4, 4, device="cuda"))

-                # Verify runtime estimates artifact was logged
-                self.assertIn('"inductor_tlparse_runtime"', self.buffer.getvalue())
+                # Verify artifact was logged
+                self.assertIn(
+                    '"inductor_runtime_and_tensor_meta"', self.buffer.getvalue()
+                )

                 payload_content = payload_buffer.getvalue().strip()
                 if payload_content:
@@ -1333,6 +1337,145 @@ def forward(self, x_1: "f32[2][1]cpu"):
         finally:
             dist.destroy_process_group()

+    @requires_tlparse
+    @requires_distributed()
+    @requires_cuda_and_triton
+    @torch._inductor.config.patch("fx_graph_cache", False)
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_multiple_ops(self):
+        import torch.distributed as dist
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+
+        class Mixed(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                y = torch.relu(self.linear(x))
+                y = torch.ops._c10d_functional.all_reduce.default(y, "sum", "0")
+                y = torch.ops._c10d_functional.wait_tensor.default(y)
+                return y + 1
+
+        try:
+            with self._setup_runtime_estimates_capture() as payload_buffer:
+                torch._dynamo.reset()
+                mod = Mixed().cuda()
+                compiled = torch.compile(mod, backend="inductor")
+                compiled(torch.randn(4, 4, device="cuda"))
+                payload = payload_buffer.getvalue().strip()
+                if payload:
+                    data = json.loads(payload)
+                    types = sorted({op.get("type") for op in data.get("ops", [])})
+                    self.assertExpectedInline(
+                        str(types), """['collective', 'compute']"""
+                    )
+            self.assertParses()
+        finally:
+            dist.destroy_process_group()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging(self):
+        """Emit unified runtime+tensor-metadata artifact and assert a stable simplified JSON inline."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", fullgraph=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+
+                self.assertExpectedInline(
+                    str({"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []}),
+                    """{'ops': [{'type': 'compute', 'outputs': [{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}""",
+                )
+
+        self.assertParses()
+
+    @requires_tlparse
+    @torch._inductor.config.patch("log_tlparse", True)
+    def test_tensor_metadata_logging_dynamic_shapes(self):
+        """Same as test_tensor_metadata_logging, but with dynamic shapes enabled to cover to_size_hints."""
+        with self._setup_runtime_estimates_capture() as payload_buffer:
+
+            def f(x):
+                y = x.transpose(0, 1)
+                z = y.mean(dim=0)
+                w = z.to(torch.float16)
+                return w
+
+            compiled = torch.compile(f, backend="inductor", dynamic=True)
+            compiled(torch.ones(2, 3))
+
+            # Verify artifact was logged
+            self.assertIn('"inductor_runtime_and_tensor_meta"', self.buffer.getvalue())
+
+            payload = payload_buffer.getvalue().strip()
+            if payload:
+                data = json.loads(payload)
+                ops = data.get("ops", [])
+
+                simplified_ops = []
+                for op in ops:
+                    outs = [
+                        {
+                            "shape": out.get("shape", []),
+                            "stride": out.get("stride", []),
+                            "dtype": out.get("dtype", None),
+                        }
+                        for out in op.get("outputs", [])
+                    ]
+                    if outs:
+                        simplified_ops.append(
+                            {
+                                "type": op.get("type", ""),
+                                "outputs": outs,
+                            }
+                        )
+
+                self.assertExpectedInline(
+                    str({"ops": simplified_ops[-1:]} if simplified_ops else {"ops": []}),
+                    (
+                        "{'ops': [{'type': 'compute', 'outputs': ["
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float32'}, "
+                        "{'shape': [2], 'stride': [1], 'dtype': 'float16'}]}]}"
+                    ),
+                )
+
+        self.assertParses()
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
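As a sanity check on the expected inline strings above: in eager mode the traced function produces a [2]-shaped float32 mean and its float16 cast, which is exactly the pair the dynamic-shapes test records (the static test keeps only the final float16 output). A quick sketch:

    import torch

    x = torch.ones(2, 3)
    z = x.transpose(0, 1).mean(dim=0)  # shape [2], stride [1], dtype torch.float32
    w = z.to(torch.float16)            # shape [2], stride [1], dtype torch.float16
    print(list(z.shape), z.stride(), z.dtype)  # [2] (1,) torch.float32
    print(list(w.shape), w.stride(), w.dtype)  # [2] (1,) torch.float16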
@@ -1526,10 +1526,10 @@ class _InProcessFxCompile(FxCompile):
                 },
             )

-            # Collect and dump op runtimes for TLParse
+            # Collect and dump op runtimes and tensor metadata for TLParse
             if config.log_tlparse:
                 _, _, node_runtimes = graph.count_bytes()
-                torch._inductor.debug.log_runtime_estimates(node_runtimes)
+                torch._inductor.debug.log_runtime_and_tensor_meta(node_runtimes)

             # Collect and dump collective-op schedule for external diagnostics
             torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes)
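As the hunk above shows, the dump is gated on config.log_tlparse. A minimal way to exercise it outside the test suite (a sketch; config.patch doubles as a context manager, as in the test decorators above):

    import torch

    def f(x):
        return x.transpose(0, 1).mean(dim=0).to(torch.float16)

    # Enable the TLParse dump for one compilation; the artifact is emitted
    # through trace_structured during Inductor's compile.
    with torch._inductor.config.patch("log_tlparse", True):
        torch.compile(f, backend="inductor")(torch.ones(2, 3))

Set TORCH_TRACE=<dir> before running to persist the structured trace for the tlparse CLI.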
@@ -737,26 +737,68 @@ def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
     _dump_collective_schedule(schedule)


-def log_runtime_estimates(node_runtimes: Sequence[tuple[Any, float]]) -> None:
-    """Log per-operation runtime estimates for TLParse."""
-    ops = [
-        {
-            "name": getattr(s.node, "python_kernel_name", s.get_name()),
-            "type": "collective" if utils.is_collective(s.node) else "compute",
-            "estimated_runtime_ns": runtime_ns,
-        }
-        for s, runtime_ns in node_runtimes
-    ]
-    trace_structured(
-        "artifact",
-        metadata_fn=lambda: {
-            "name": "inductor_tlparse_runtime",
-            "encoding": "json",
-        },
-        payload_fn=lambda: {"ops": ops},
-    )
+def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
+    """Log per-op runtime estimates and output tensor metadata for TLParse."""
+    try:
+        to_size_hints = V.graph.sizevars.size_hints
+
+        def to_list(x: Optional[Sequence[Any]]) -> list[Any]:
+            return list(to_size_hints(x)) if x is not None else []
+
+        def dtype_to_str(dtype: Any) -> Optional[str]:
+            if dtype is None:
+                return None
+            s = str(dtype)
+            s = s.removeprefix("torch.")
+            return s
+
+        ops: list[dict[str, Any]] = []
+        for s, runtime_ns in node_runtimes:
+            name = getattr(s.node, "python_kernel_name", s.get_name())
+            op_type = "collective" if utils.is_collective(s.node) else "compute"
+
+            # Build outputs metadata if available
+            outputs: list[dict[str, Any]] = []
+            try:
+                for buf in s.get_outputs():
+                    irnode = buf.node
+                    shape = irnode.maybe_get_size()
+                    stride = (
+                        irnode.get_stride()
+                        if isinstance(irnode.layout, ir.Layout)
+                        else None
+                    )
+                    dtype = irnode.maybe_get_dtype()
+                    outputs.append(
+                        {
+                            "shape": to_list(shape),
+                            "stride": to_list(stride),
+                            "dtype": dtype_to_str(dtype),
+                        }
+                    )
+            except Exception:
+                pass
+
+            ops.append(
+                {
+                    "name": name,
+                    "type": op_type,
+                    "estimated_runtime_ns": runtime_ns,
+                    "outputs": outputs,
+                }
+            )
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "inductor_runtime_and_tensor_meta",
+                "encoding": "json",
+            },
+            payload_fn=lambda: {"ops": ops},
+        )
+    except Exception:
+        log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)


 @dataclasses.dataclass
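Downstream, a consumer can pull the artifact back out of a TORCH_TRACE log for cross-rank aggregation. A hypothetical reader, assuming the structured-trace convention that payload lines follow their metadata line prefixed with a tab (read_runtime_and_tensor_meta is illustrative, not a PyTorch API):

    import json

    def read_runtime_and_tensor_meta(trace_path):
        """Collect op records from one rank's structured trace file (illustrative)."""
        ops = []
        with open(trace_path) as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            if '"inductor_runtime_and_tensor_meta"' not in line:
                continue
            # Payload lines are assumed to follow the metadata line, tab-prefixed.
            payload = []
            for follow in lines[i + 1:]:
                if not follow.startswith("\t"):
                    break
                payload.append(follow[1:])
            if payload:
                ops.extend(json.loads("".join(payload)).get("ops", []))
        return ops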