Output tensor metadata for FX graph node (#159311)
FX graph segments in CompiledFxGraph do not include tensor metadata, for example tensor shape, stride, data type, and device. The AI system co-design team requested that this information be included in FX graph segments so they can use them to project performance on different hardware.
This diff modifies `torch.fx.node.Node.format_node` to include tensor metadata.
Before this diff, the Triton kernel FX graph segment looked like the following:
```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```
After this diff:
```
# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]
# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]
# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})
# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})
# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})
# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})
# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})
# return %cos
```
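For reference, each quoted annotation packs the dtype abbreviation, shape, stride, and device into a single string: `f32[4, 4][4, 1]cuda:0` is a float32 4x4 tensor with strides (4, 1) on cuda:0. A minimal sketch of how such a string can be assembled from a live tensor; the helper name `tensor_annotation` is illustrative, not part of this PR:
```
import torch
from torch.utils._dtype_abbrs import dtype_abbrs  # e.g. torch.float32 -> "f32"


def tensor_annotation(t: torch.Tensor) -> str:
    # dtype abbreviation + shape + stride + device, e.g. f32[4, 4][4, 1]cuda:0
    shape = ", ".join(str(d) for d in t.shape)
    stride = ", ".join(str(s) for s in t.stride())
    return f"{dtype_abbrs[t.dtype]}[{shape}][{stride}]{t.device}"


print(tensor_annotation(torch.randn(4, 4)))  # f32[4, 4][4, 1]cpu
```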
If `format_node` cannot be changed, I can copy the code to caffe2/torch/_inductor/utils.py.
Differential Revision: D77973076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159311
Approved by: https://github.com/angelayi
This commit is contained in:
parent 595a65f5c2
commit 0450f05658
```
@@ -52,7 +52,7 @@ torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple
 torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule
 torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None
 torch.fx.node.Node.append(self, x: 'Node') -> None
-torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str]
+torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str]
 torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
 torch.fx.node.Node.prepend(self, x: 'Node') -> None
 torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = <function <lambda>>, propagate_meta: bool = False) -> List[Node]
```
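The signature change above is the public API surface. In use, the new flag is passed as a keyword (a sketch; `gm` stands for any `torch.fx.GraphModule` whose nodes carry tensor metadata):
```
for node in gm.graph.nodes:
    # New keyword-only flag from this PR; prints the annotated form shown above.
    print(node.format_node(include_tensor_metadata=True))
```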
```
@@ -14532,11 +14532,11 @@ if RUN_GPU:
         else:
             self.assertTrue("Graph fragment" in code)
             self.assertTrue(
-                "%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default]"
+                '%sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]'
                 in code
             )
             self.assertTrue(
-                "%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default]"
+                '%relu : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]'
                 in code
             )
```
```
@@ -425,6 +425,11 @@ class TestExecutionTrace(TestCase):
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
     def test_execution_trace_env_enabled_with_pt2(self, device):
+        # clean up the local cache for triton kernel
+        from torch._inductor.codecache import PyCodeCache as PyCodeCache
+
+        PyCodeCache.cache_clear(purge=True)
+
         import os
 
         os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
```
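Presumably the `PyCodeCache` purge forces the Triton kernels to be recompiled on this run, so the execution trace observer actually sees the compilation artifacts instead of a cache hit.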
```
@@ -439,7 +444,9 @@ class TestExecutionTrace(TestCase):
         a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3))
 
         inputs = [a, b, c]
-        with torch._inductor.config.patch(compile_threads=1):
+        with torch._inductor.config.patch(
+            compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
+        ):
             fn(*inputs)
 
         with profile(
```
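Disabling `fx_graph_cache` and `fx_graph_remote_cache` likewise appears intended to guarantee a fresh compile, so the FX graph segments are regenerated and recorded rather than served from cache.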
```
@@ -480,10 +487,12 @@ class TestExecutionTrace(TestCase):
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
     def test_triton_fx_graph_with_et(self, device):
-        import os
+        # clean up the local cache for triton kernel
+        from torch._inductor.codecache import PyCodeCache as PyCodeCache
 
         os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1"
         os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1"
+        PyCodeCache.cache_clear(purge=True)
+
+        import os
 
         @torchdynamo.optimize("inductor")
         def fn(a, b, c):
```
```
@@ -503,12 +512,18 @@ class TestExecutionTrace(TestCase):
         ):
             fn(*inputs)
 
+        fp = tempfile.NamedTemporaryFile("w+t", suffix="fx_graph_et.json", delete=False)
+        fp.close()
+        et = ExecutionTraceObserver()
+        et.register_callback(fp.name)
+        et.set_extra_resource_collection(True)
         with profile(
             activities=torch.profiler.supported_activities(),
             record_shapes=True,
             schedule=torch.profiler.schedule(
                 skip_first=0, wait=1, warmup=1, active=1, repeat=1
             ),
+            execution_trace_observer=et,
         ) as p:
             for idx in range(10):
                 with record_function(f"## LOOP {idx} ##"):
```
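`et.set_extra_resource_collection(True)` mirrors the `ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS` env var set above; both appear to opt the observer into collecting extra resources such as these FX graph segments.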
```
@@ -550,23 +565,23 @@ class TestExecutionTrace(TestCase):
         )
         assert (
             fx_graph[2]
-            == "# %sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})"  # noqa: B950
+            == '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})'  # noqa: B950
         )
         assert (
             fx_graph[3]
-            == "# %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})"  # noqa: B950
+            == '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})'  # noqa: B950
         )
         assert (
             fx_graph[4]
-            == "# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})"  # noqa: B950
+            == '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})'  # noqa: B950
         )
         assert (
             fx_graph[5]
-            == "# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})"  # noqa: B950
+            == '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})'  # noqa: B950
         )
         assert (
             fx_graph[6]
-            == "# %cos : [num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})"  # noqa: B950
+            == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})'  # noqa: B950
         )
         assert fx_graph[7] == "# return %cos"
```
```
@@ -856,7 +856,9 @@ def get_kernel_metadata(
             all_writes.append("%" + output_name)
 
     for node in inductor_nodes:
-        detailed_metadata.append(f"{wrapper.comment} {node.format_node()}")
+        detailed_metadata.append(
+            f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
+        )
 
     detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")
```
```
@@ -4,7 +4,7 @@ import inspect
 import logging
 import operator
 import types
-from collections.abc import Mapping, Sequence
+from collections.abc import Iterable, Mapping, Sequence
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec
```
```
@@ -15,6 +15,7 @@ from torch.fx.operator_schemas import (
     normalize_function,
     normalize_module,
 )
+from torch.utils._dtype_abbrs import dtype_abbrs
 
 from .._ops import ops as _ops
 from ._compatibility import compatibility
```
```
@@ -597,6 +598,8 @@ class Node(_NodeBase):
         self,
         placeholder_names: Optional[list[str]] = None,
         maybe_return_typename: Optional[list[str]] = None,
+        *,
+        include_tensor_metadata: bool = False,
     ) -> Optional[str]:
         """
         Return a descriptive string representation of ``self``.
```
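Note the `*`: the new flag is keyword-only, so existing positional call sites are unaffected. For example (`node` being any `torch.fx.Node`):
```
node.format_node(None, None, include_tensor_metadata=True)  # OK: keyword-only
# node.format_node(None, None, True)  # TypeError: flag cannot be passed positionally
```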
```
@@ -618,6 +621,7 @@ class Node(_NodeBase):
             maybe_return_typename: A single-element list that will store
                 a formatted string representing the output of the
                 generated ``forward`` function. Internal use only.
+            include_tensor_metadata: Whether to include tensor metadata
 
         Returns:
             str: If 1) we're using ``format_node`` as an internal helper
```
```
@@ -649,11 +653,36 @@ class Node(_NodeBase):
             maybe_return_typename[0] = f" -> {_type_repr(self.type)}"
             return f"return {self.args[0]}"
         else:
-            maybe_typename = (
-                f"{_type_repr(self.type)} " if self.type is not None else ""
-            )
+
+            def stringify_shape(shape: Iterable) -> str:
+                return f"[{', '.join([str(x) for x in shape])}]"
+
+            meta_val = self.meta.get(
+                "val",
+                self.meta.get("tensor_meta", self.meta.get("example_value", None)),
+            )
+            type_annotation = ""
+            if (
+                include_tensor_metadata
+                and isinstance(meta_val, torch.Tensor)
+                and meta_val.layout
+                not in (
+                    torch.sparse_csc,
+                    torch.sparse_csr,
+                )
+            ):
+                stride_annotation = f"{stringify_shape(meta_val.stride())}"
+                device_annotation = f"{meta_val.device}"
+                type_annotation = (
+                    f'Tensor "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
+                    f'{stride_annotation}{device_annotation}"'
+                )
+            else:
+                type_annotation = (
+                    f"{_type_repr(self.type)} " if self.type is not None else ""
+                )
             return (
-                f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
+                f"%{self.name} : {type_annotation}[num_users={len(self.users)}] = "
                 f"{self.op}[target={self._pretty_print_target(self.target)}]("
                 f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})"
             )
```
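End to end, the annotation is derived from `node.meta` (`val`, `tensor_meta`, or `example_value`), so it only appears for graphs traced with tensor propagation. A sketch reproducing the PR's example graph; using `make_fx` here is one way to populate `meta["val"]`, not the only one, and on CPU the device prints as `cpu` rather than `cuda:0`:
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x, y):
    return torch.cos(torch.sin(x).t() + y * 1111)


gm = make_fx(f, tracing_mode="fake")(torch.randn(4, 4), torch.randn(4, 4))
for node in gm.graph.nodes:
    # With the new flag, call_function nodes print as, e.g.:
    # %sin : Tensor "f32[4, 4][4, 1]cpu"[num_users=1] = call_function[...]
    print(node.format_node(include_tensor_metadata=True))
```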