Augment DebugMode to support attributes reporting (#165109)
DebugMode reports tensor type, shape, and placements while active. This change augments the report with tensor attributes drawn from a configured set. The feature is intended to make the debug string easier to follow for larger outputs: for example, before running a model's forward pass we can annotate each parameter and buffer with its fully qualified name, so that we can see which ops are executed against which specific tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165109
Approved by: https://github.com/ezyang, https://github.com/pianpwk
This commit is contained in:
parent f363114852
commit 1e35b3c4e0
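As an illustration of the fully-qualified-name workflow from the description, here is a minimal sketch (not part of this commit). The attribute name "fqn" is arbitrary, chosen for this example, and the import path assumes DebugMode lives in torch.utils._debug_mode, consistent with the diff below:

# Minimal sketch of the FQN use case described above. "fqn" is an
# arbitrary attribute name for this example; record_tensor_attributes
# is the argument added by this commit.
import torch
from torch.utils._debug_mode import DebugMode

model = torch.nn.Linear(8, 8)

# Annotate every parameter (and buffer) with its fully qualified name.
for name, p in model.named_parameters():
    p.fqn = name
for name, b in model.named_buffers():
    b.fqn = name

with DebugMode(record_tensor_attributes=["fqn"]) as debug_mode:
    model(torch.randn(2, 8))

# Op lines now render annotated tensors as e.g. t: f32[8, 8]{fqn=weight}.
print(debug_mode.debug_string())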
@@ -215,6 +215,29 @@ class TestDTensorDebugMode(TestCase):
     aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
         )
 
+    def test_tensor_attributes(self):
+        x = torch.randn(8, 8)
+        x.a1 = "x1"
+        x.a2 = "x2"
+        y = torch.randn(8, 8, 8)
+        y.a1 = "y"
+
+        with DebugMode(
+            record_torchfunction=True,
+            record_faketensor=True,
+            record_tensor_attributes=["a1", "a2"],
+        ) as debug_mode:
+            torch.matmul(y, x)
+
+        self.assertExpectedInline(
+            debug_mode.debug_string(),
+            """\
+  torch.matmul(t: f32[8, 8, 8]{a1=y}, t: f32[8, 8]{a1=x1, a2=x2})
+    aten::view(t: f32[8, 8, 8]{a1=y}, [64, 8])
+    aten::mm(t: f32[64, 8], t: f32[8, 8]{a1=x1, a2=x2})
+    aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
+        )
+
     @parametrize("has_inner_mode", [True, False])
     @parametrize("has_outer_mode", [True, False])
     def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
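Note how the attributes propagate in the expected output above: the matmul decomposes into view, mm, and _unsafe_view, and only tensors annotated before the call carry an attribute suffix; intermediates created inside the op, such as the f32[64, 8] view result, render without one.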
@@ -30,25 +30,39 @@ def _stringify_placement(placement) -> str:
     return f"[{', '.join([str(p) for p in placement])}]"
 
 
-def _tensor_debug_string(tensor) -> str:
+def _stringify_attributes(tensor, attributes) -> str:
+    pairs = {}
+    for attr in attributes:
+        if hasattr(tensor, attr):
+            pairs[attr] = getattr(tensor, attr)
+    if len(pairs) == 0:
+        return ""
+    return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"
+
+
+def _tensor_debug_string(tensor, attributes) -> str:
     """Convert tensor to debug string representation."""
-    if isinstance(tensor, torch.distributed.tensor.DTensor):
-        # omitted device mesh
-        return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
-    elif isinstance(tensor, FakeTensor):
-        return f"ft: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
-    elif isinstance(tensor, torch.Tensor):
-        return f"t: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
+    if isinstance(tensor, torch.Tensor):
+        tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
+
+        if isinstance(tensor, torch.distributed.tensor.DTensor):
+            # omitted device mesh
+            return f"dt: {tensor_debug_str}{_stringify_placement(tensor.placements)}"
+        elif isinstance(tensor, FakeTensor):
+            return f"ft: {tensor_debug_str}"
+        else:
+            return f"t: {tensor_debug_str}"
     else:
         raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
 
 
-def _arg_to_str(arg) -> str:
+def _arg_to_str(arg, attributes) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec
 
     def to_str(x):
         if isinstance(x, torch.Tensor):
-            return _tensor_debug_string(x)
+            return _tensor_debug_string(x, attributes)
         elif isinstance(x, DTensorSpec):
             return _stringify_placement(x.placements)
         return x
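To make the suffix format concrete, here is a self-contained sketch of the _stringify_attributes logic above (plain Python, no torch dependency; a simple object stands in for a tensor since only attribute lookup matters):

# Illustration of the attribute-suffix logic from _stringify_attributes.
class _Obj:
    pass

def stringify_attributes(obj, attributes):
    # Collect only the attributes actually present on the object.
    pairs = {a: getattr(obj, a) for a in attributes if hasattr(obj, a)}
    if not pairs:
        return ""
    return f"{{{', '.join(f'{k}={v}' for k, v in pairs.items())}}}"

t = _Obj()
t.a1 = "x1"
t.a2 = "x2"

print(stringify_attributes(t, ["a1", "a2"]))  # {a1=x1, a2=x2}
print(stringify_attributes(t, ["a3"]))        # "" (missing attrs are skipped)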
@@ -57,17 +71,17 @@ def _arg_to_str(arg) -> str:
     return str(arg)
 
 
-def _op_to_str(op, *args, **kwargs) -> str:
+def _op_to_str(op, attributes, *args, **kwargs) -> str:
     if op == REDISTRIBUTE_FUNC:
         assert len(args) == 3
-        _args = [_arg_to_str(arg) for arg in args]
+        _args = [_arg_to_str(arg, attributes) for arg in args]
         args_str = f"{_args[0]}, {_args[1]} -> {_args[2]}"
     else:
-        args_str = ", ".join(_arg_to_str(arg) for arg in args)
+        args_str = ", ".join(_arg_to_str(arg, attributes) for arg in args)
 
     if kwargs:
         kwargs_str = ", " + ", ".join(
-            f"{k}={_arg_to_str(v)}" for k, v in kwargs.items()
+            f"{k}={_arg_to_str(v, attributes)}" for k, v in kwargs.items()
        )
     else:
         kwargs_str = ""
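The REDISTRIBUTE_FUNC branch above renders its three arguments as a transition rather than a plain argument list. A minimal sketch of the two formatting paths, with argument stringification stubbed out (the sample placement strings are illustrative):

# Sketch of the two args_str formats in _op_to_str: a redistribute call
# renders as "input, src_placement -> dst_placement"; all other ops are
# comma-joined.
def format_args(is_redistribute, args):
    strs = [str(a) for a in args]
    if is_redistribute:
        assert len(strs) == 3
        return f"{strs[0]}, {strs[1]} -> {strs[2]}"
    return ", ".join(strs)

print(format_args(True, ["dt: f32[8]", "[Shard(0)]", "[Replicate()]"]))
# dt: f32[8], [Shard(0)] -> [Replicate()]
print(format_args(False, ["t: f32[8, 8]", "t: f32[8]"]))
# t: f32[8, 8], t: f32[8]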
@@ -89,6 +103,7 @@ class DebugMode(TorchDispatchMode):
         record_torchfunction=False,
         record_faketensor=False,
         record_realtensor=True,
+        record_tensor_attributes=None,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
@@ -97,6 +112,7 @@ class DebugMode(TorchDispatchMode):
         self.record_torchfunction = record_torchfunction
         self.record_faketensor = record_faketensor
         self.record_realtensor = record_realtensor
+        self.record_tensor_attributes = record_tensor_attributes or []
 
         self.operators = []
         self.call_depth = 0
@@ -178,7 +194,9 @@ class DebugMode(TorchDispatchMode):
         with torch._C.DisableTorchFunction():
             result = ""
             result += "\n".join(
-                "  " + "  " * depth + _op_to_str(op, *args, **kwargs)
+                "  "
+                + "  " * depth
+                + _op_to_str(op, self.record_tensor_attributes, *args, **kwargs)
                 for op, args, kwargs, depth in self.operators
             )
             return result
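For reference, a minimal sketch of the indentation scheme in the rendering loop above: every line gets two leading spaces plus two more per call depth (the operator tuples here are faked for brevity):

# Illustration of the debug_string indentation: "  " + "  " * depth.
operators = [
    ("torch.matmul(...)", 0),  # torch-function level call
    ("aten::view(...)", 1),    # dispatched ops recorded one level deeper
    ("aten::mm(...)", 1),
]
print("\n".join("  " + "  " * depth + op for op, depth in operators))
#   torch.matmul(...)
#     aten::view(...)
#     aten::mm(...)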