Output tensor metadata for FX graph node (#159311)

The FX graph segment in CompiledFxGraph does not include tensor metadata such as shape, stride, data type, and device. The AI system co-design team requested that this information be included in the FX graph segment so they can use it to project performance on different hardware.
This diff modifies torch.fx.Node.format_node to include tensor metadata.
Before this diff, the Triton kernel FX graph segment looks like the following:
```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```

After this diff:
```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```
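Each annotation packs the dtype abbreviation, shape, stride, and device into a single string: for example, `Tensor "f32[4, 4][1, 4]cuda:0"` is a float32 tensor of shape [4, 4] with strides [1, 4] on device cuda:0. The dtype abbreviations come from torch.utils._dtype_abbrs.dtype_abbrs.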
If format_node cannot be changed, I can copy the code to caffe2/torch/_inductor/utils.py.
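For reference, a minimal sketch of exercising the new flag directly (assuming a PyTorch build that already contains this change; the traced function and shapes are only illustrative):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    return torch.sin(x).t() + x * 1111


# make_fx records FakeTensor values under node.meta["val"]; that is the first
# place format_node(include_tensor_metadata=True) looks for tensor metadata.
gm = make_fx(f, tracing_mode="fake")(torch.randn(4, 4))

for node in gm.graph.nodes:
    # e.g. %sin : Tensor "f32[4, 4][4, 1]cpu"[num_users=1] = call_function[...]
    print(node.format_node(include_tensor_metadata=True))
```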

Differential Revision: D77973076

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159311
Approved by: https://github.com/angelayi
Sheng Fu 2025-08-01 21:40:25 +00:00 committed by PyTorch MergeBot
parent 595a65f5c2
commit 0450f05658
5 changed files with 63 additions and 17 deletions


@ -52,7 +52,7 @@ torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple
torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule
torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None
torch.fx.node.Node.append(self, x: 'Node') -> None
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str]
torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
torch.fx.node.Node.prepend(self, x: 'Node') -> None
torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
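Note that include_tensor_metadata is keyword-only and defaults to False, so existing callers of format_node keep their current output; only call sites that opt in (such as the Inductor change below) emit the tensor annotations.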


@ -14532,11 +14532,11 @@ if RUN_GPU:
else:
self.assertTrue("Graph fragment" in code)
self.assertTrue(
"%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default]"
'%sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]'
in code
)
self.assertTrue(
"%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default]"
'%relu : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]'
in code
)


@ -425,6 +425,11 @@ class TestExecutionTrace(TestCase):
)
@skipCPUIf(True, "skip CPU device for testing profiling triton")
def test_execution_trace_env_enabled_with_pt2(self, device):
# clean up the local cache for triton kernel
from torch._inductor.codecache import PyCodeCache as PyCodeCache
PyCodeCache.cache_clear(purge=True)
import os
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
@ -439,7 +444,9 @@ class TestExecutionTrace(TestCase):
a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3))
inputs = [a, b, c]
with torch._inductor.config.patch(compile_threads=1):
with torch._inductor.config.patch(
compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
):
fn(*inputs)
with profile(
@ -480,10 +487,12 @@ class TestExecutionTrace(TestCase):
)
@skipCPUIf(True, "skip CPU device for testing profiling triton")
def test_triton_fx_graph_with_et(self, device):
import os
# clean up the local cache for triton kernel
from torch._inductor.codecache import PyCodeCache as PyCodeCache
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1"
PyCodeCache.cache_clear(purge=True)
import os
@torchdynamo.optimize("inductor")
def fn(a, b, c):
@ -503,12 +512,18 @@ class TestExecutionTrace(TestCase):
):
fn(*inputs)
fp = tempfile.NamedTemporaryFile("w+t", suffix="fx_graph_et.json", delete=False)
fp.close()
et = ExecutionTraceObserver()
et.register_callback(fp.name)
et.set_extra_resource_collection(True)
with profile(
activities=torch.profiler.supported_activities(),
record_shapes=True,
schedule=torch.profiler.schedule(
skip_first=0, wait=1, warmup=1, active=1, repeat=1
),
execution_trace_observer=et,
) as p:
for idx in range(10):
with record_function(f"## LOOP {idx} ##"):
@ -550,23 +565,23 @@ class TestExecutionTrace(TestCase):
)
assert (
fx_graph[2]
== "# %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})" # noqa: B950
== '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})' # noqa: B950
)
assert (
fx_graph[3]
== "# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})" # noqa: B950
== '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})' # noqa: B950
)
assert (
fx_graph[4]
== "# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})" # noqa: B950
== '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})' # noqa: B950
)
assert (
fx_graph[5]
== "# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})" # noqa: B950
== '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})' # noqa: B950
)
assert (
fx_graph[6]
== "# %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})" # noqa: B950
== '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})' # noqa: B950
)
assert fx_graph[7] == "# return %cos"


@ -856,7 +856,9 @@ def get_kernel_metadata(
all_writes.append("%" + output_name)
for node in inductor_nodes:
detailed_metadata.append(f"{wrapper.comment} {node.format_node()}")
detailed_metadata.append(
f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
)
detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")


@ -4,7 +4,7 @@ import inspect
import logging
import operator
import types
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec
@ -15,6 +15,7 @@ from torch.fx.operator_schemas import (
normalize_function,
normalize_module,
)
from torch.utils._dtype_abbrs import dtype_abbrs
from .._ops import ops as _ops
from ._compatibility import compatibility
@ -597,6 +598,8 @@ class Node(_NodeBase):
self,
placeholder_names: Optional[list[str]] = None,
maybe_return_typename: Optional[list[str]] = None,
*,
include_tensor_metadata: bool = False,
) -> Optional[str]:
"""
Return a descriptive string representation of ``self``.
@ -618,6 +621,7 @@ class Node(_NodeBase):
maybe_return_typename: A single-element list that will store
a formatted string representing the output of the
generated ``forward`` function. Internal use only.
include_tensor_metadata: Whether to include tensor metadata
Returns:
str: If 1) we're using ``format_node`` as an internal helper
@ -649,11 +653,36 @@ class Node(_NodeBase):
maybe_return_typename[0] = f" -> {_type_repr(self.type)}"
return f"return {self.args[0]}"
else:
maybe_typename = (
f"{_type_repr(self.type)} " if self.type is not None else ""
def stringify_shape(shape: Iterable) -> str:
return f"[{', '.join([str(x) for x in shape])}]"
meta_val = self.meta.get(
"val",
self.meta.get("tensor_meta", self.meta.get("example_value", None)),
)
type_annotation = ""
if (
include_tensor_metadata
and isinstance(meta_val, torch.Tensor)
and meta_val.layout
not in (
torch.sparse_csc,
torch.sparse_csr,
)
):
stride_annotation = f"{stringify_shape(meta_val.stride())}"
device_annotation = f"{meta_val.device}"
type_annotation = (
f'Tensor "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
f'{stride_annotation}{device_annotation}"'
)
else:
type_annotation = (
f"{_type_repr(self.type)} " if self.type is not None else ""
)
return (
f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
f"%{self.name} : {type_annotation}[num_users={len(self.users)}] = "
f"{self.op}[target={self._pretty_print_target(self.target)}]("
f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})"
)
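For a quick sanity check of the metadata lookup order, here is a small hand-built sketch (illustrative only; in real compiled graphs meta["val"] is populated by fake-tensor propagation rather than by hand):

```python
import torch

g = torch.fx.Graph()
x = g.placeholder("x")
sin = g.call_function(torch.sin, (x,))

# format_node checks node.meta["val"] first, then "tensor_meta", then
# "example_value"; any torch.Tensor works here, though in compiled graphs
# "val" normally holds a FakeTensor.
sin.meta["val"] = torch.empty(4, 4, dtype=torch.float32)

# Prints something like:
# %sin : Tensor "f32[4, 4][4, 1]cpu"[num_users=0] = call_function[target=torch.sin](args = (%x,), kwargs = {})
print(sin.format_node(include_tensor_metadata=True))
```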