Augment DebugMode to support attributes reporting (#165109)

DebugMode reports tensor type, shape, and placements while active. This change augments the report to include tensor attributes from a configured set. This feature is intended to ease understanding of the debug string when dealing with larger outputs. For example, before running the forward pass of a model we can annotate each of its parameters and buffers with their fully qualified names, so that we can see which ops are executed against specific tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165109
Approved by: https://github.com/ezyang, https://github.com/pianpwk
This commit is contained in:
Dzmitry Huba 2025-10-10 21:27:01 +00:00 committed by PyTorch MergeBot
parent f363114852
commit 1e35b3c4e0
2 changed files with 56 additions and 15 deletions

View File

@ -215,6 +215,29 @@ class TestDTensorDebugMode(TestCase):
aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
)
def test_tensor_attributes(self):
    """DebugMode should render configured attributes next to each tensor."""
    # Tag tensors with ad-hoc attributes; only the names listed in
    # record_tensor_attributes are expected to show up in the output.
    lhs = torch.randn(8, 8)
    lhs.a1 = "x1"
    lhs.a2 = "x2"
    rhs = torch.randn(8, 8, 8)
    rhs.a1 = "y"
    with DebugMode(
        record_torchfunction=True,
        record_faketensor=True,
        record_tensor_attributes=["a1", "a2"],
    ) as debug_mode:
        torch.matmul(rhs, lhs)
    self.assertExpectedInline(
        debug_mode.debug_string(),
        """\
torch.matmul(t: f32[8, 8, 8]{a1=y}, t: f32[8, 8]{a1=x1, a2=x2})
aten::view(t: f32[8, 8, 8]{a1=y}, [64, 8])
aten::mm(t: f32[64, 8], t: f32[8, 8]{a1=x1, a2=x2})
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
    )
@parametrize("has_inner_mode", [True, False])
@parametrize("has_outer_mode", [True, False])
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):

View File

@ -30,25 +30,39 @@ def _stringify_placement(placement) -> str:
return f"[{', '.join([str(p) for p in placement])}]"
def _tensor_debug_string(tensor) -> str:
def _stringify_attributes(tensor, attributes) -> str:
pairs = {}
for attr in attributes:
if hasattr(tensor, attr):
pairs[attr] = getattr(tensor, attr)
if len(pairs) == 0:
return ""
return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"
def _tensor_debug_string(tensor, attributes) -> str:
    """Convert tensor to debug string representation.

    The string is "<prefix>: <dtype><shape>{attrs}" where the prefix is
    "dt" for DTensor (placements appended, device mesh omitted), "ft" for
    FakeTensor, and "t" for any other torch.Tensor. `attributes` is the
    list of attribute names to render via _stringify_attributes.

    Raises RuntimeError for non-tensor inputs.
    """
    # NOTE(review): the rendered diff had the pre-change branches interleaved
    # with the new body; this is the coherent post-change implementation.
    if isinstance(tensor, torch.Tensor):
        # Core shared by all tensor kinds: dtype abbreviation, shape, and
        # any recorded attributes present on this tensor.
        tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
        if isinstance(tensor, torch.distributed.tensor.DTensor):
            # omitted device mesh
            return f"dt: {tensor_debug_str}{_stringify_placement(tensor.placements)}"
        elif isinstance(tensor, FakeTensor):
            return f"ft: {tensor_debug_str}"
        else:
            return f"t: {tensor_debug_str}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
def _arg_to_str(arg) -> str:
def _arg_to_str(arg, attributes) -> str:
from torch.distributed.tensor._dtensor_spec import DTensorSpec
def to_str(x):
if isinstance(x, torch.Tensor):
return _tensor_debug_string(x)
return _tensor_debug_string(x, attributes)
elif isinstance(x, DTensorSpec):
return _stringify_placement(x.placements)
return x
@ -57,17 +71,17 @@ def _arg_to_str(arg) -> str:
return str(arg)
def _op_to_str(op, *args, **kwargs) -> str:
def _op_to_str(op, attributes, *args, **kwargs) -> str:
if op == REDISTRIBUTE_FUNC:
assert len(args) == 3
_args = [_arg_to_str(arg) for arg in args]
_args = [_arg_to_str(arg, attributes) for arg in args]
args_str = f"{_args[0]}, {_args[1]} -> {_args[2]}"
else:
args_str = ", ".join(_arg_to_str(arg) for arg in args)
args_str = ", ".join(_arg_to_str(arg, attributes) for arg in args)
if kwargs:
kwargs_str = ", " + ", ".join(
f"{k}={_arg_to_str(v)}" for k, v in kwargs.items()
f"{k}={_arg_to_str(v, attributes)}" for k, v in kwargs.items()
)
else:
kwargs_str = ""
@ -89,6 +103,7 @@ class DebugMode(TorchDispatchMode):
record_torchfunction=False,
record_faketensor=False,
record_realtensor=True,
record_tensor_attributes=None,
):
super().__init__()
import torch.distributed.tensor # noqa: F401
@ -97,6 +112,7 @@ class DebugMode(TorchDispatchMode):
self.record_torchfunction = record_torchfunction
self.record_faketensor = record_faketensor
self.record_realtensor = record_realtensor
self.record_tensor_attributes = record_tensor_attributes or []
self.operators = []
self.call_depth = 0
@ -178,7 +194,9 @@ class DebugMode(TorchDispatchMode):
with torch._C.DisableTorchFunction():
result = ""
result += "\n".join(
" " + " " * depth + _op_to_str(op, *args, **kwargs)
" "
+ " " * depth
+ _op_to_str(op, self.record_tensor_attributes, *args, **kwargs)
for op, args, kwargs, depth in self.operators
)
return result