mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
DebugMode reports tensor types, shapes, and placements while active. This change augments the report with tensor attributes drawn from a configurable set. The feature is intended to make the debug string easier to interpret for larger outputs. For example, before running a model's forward pass we can annotate each parameter and buffer with its fully qualified name, so that we can see which ops are executed against which tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165109 Approved by: https://github.com/ezyang, https://github.com/pianpwk
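For illustration, a minimal usage sketch. It assumes the module path torch.utils._debug_mode and uses "fqn" as an attribute name we choose ourselves; any attribute listed in record_tensor_attributes works the same way.

import torch
from torch.utils._debug_mode import DebugMode

model = torch.nn.Linear(4, 4)

# Annotate parameters and buffers with their fully qualified names so the
# debug string shows which tensors each op touches; "fqn" is an arbitrary
# attribute name chosen for this example.
for name, param in model.named_parameters():
    param.fqn = name
for name, buf in model.named_buffers():
    buf.fqn = name

with DebugMode(record_tensor_attributes=["fqn"]) as debug_mode:
    model(torch.randn(2, 4))

# Each recorded op is rendered with its args; annotated tensors carry {fqn=...}.
print(debug_mode.debug_string())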
212 lines
6.4 KiB
Python
# mypy: allow-untyped-defs
import contextlib
from typing import Optional

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map


__all__ = ["DebugMode", "get_active_debug_mode"]

REDISTRIBUTE_FUNC = "redistribute_input"


def _stringify_shape(shape) -> str:
    return f"[{', '.join([str(x) for x in shape])}]"


def _stringify_device_mesh(mesh) -> str:
    return f"DM({', '.join([str(s) for s in mesh.shape])})"


def _stringify_placement(placement) -> str:
    return f"[{', '.join([str(p) for p in placement])}]"


def _stringify_attributes(tensor, attributes) -> str:
    pairs = {}
    for attr in attributes:
        if hasattr(tensor, attr):
            pairs[attr] = getattr(tensor, attr)
    if len(pairs) == 0:
        return ""
    return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"


def _tensor_debug_string(tensor, attributes) -> str:
    """Convert tensor to debug string representation."""
    if isinstance(tensor, torch.Tensor):
        tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"

        if isinstance(tensor, torch.distributed.tensor.DTensor):
            # omitted device mesh
            return f"dt: {tensor_debug_str}{_stringify_placement(tensor.placements)}"
        elif isinstance(tensor, FakeTensor):
            return f"ft: {tensor_debug_str}"
        else:
            return f"t: {tensor_debug_str}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")


def _arg_to_str(arg, attributes) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    def to_str(x):
        if isinstance(x, torch.Tensor):
            return _tensor_debug_string(x, attributes)
        elif isinstance(x, DTensorSpec):
            return _stringify_placement(x.placements)
        return x

    arg = tree_map(to_str, arg)
    return str(arg)


def _op_to_str(op, attributes, *args, **kwargs) -> str:
    if op == REDISTRIBUTE_FUNC:
        assert len(args) == 3
        _args = [_arg_to_str(arg, attributes) for arg in args]
        args_str = f"{_args[0]}, {_args[1]} -> {_args[2]}"
    else:
        args_str = ", ".join(_arg_to_str(arg, attributes) for arg in args)

    if kwargs:
        kwargs_str = ", " + ", ".join(
            f"{k}={_arg_to_str(v, attributes)}" for k, v in kwargs.items()
        )
    else:
        kwargs_str = ""

    if isinstance(op, torch._ops.OpOverload):
        op_name = op.__qualname__
    elif hasattr(op, "__module__") and hasattr(op, "__name__"):
        op_name = f"{op.__module__}.{op.__name__}"
    else:
        op_name = str(op)

    return f"{op_name}({args_str}{kwargs_str})"


class DebugMode(TorchDispatchMode):
    def __init__(
        self,
        *,
        record_torchfunction=False,
        record_faketensor=False,
        record_realtensor=True,
        record_tensor_attributes=None,
    ):
        super().__init__()
        import torch.distributed.tensor  # noqa: F401

        self.supports_higher_order_operators = True
        self.record_torchfunction = record_torchfunction
        self.record_faketensor = record_faketensor
        self.record_realtensor = record_realtensor
        self.record_tensor_attributes = record_tensor_attributes or []

        self.operators = []
        self.call_depth = 0

    # Without this override, running torch.compile under DebugMode
    # will force torch.compile to always use the "eager" backend.
    # With this, DebugMode will not take effect on torch.compile.
    @classmethod
    def ignore_compile_internals(cls):
        return True

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        self.operators.append((func, args, kwargs, self.call_depth))

        try:
            self.call_depth += 1
            return func(*args, **kwargs)
        finally:
            self.call_depth -= 1

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Record the operation with its call depth
        if torch.distributed.tensor.DTensor in types:
            self.operators.append((func, args, kwargs, self.call_depth))
            return NotImplemented
        elif FakeTensor in types or isinstance(
            _get_current_dispatch_mode(), FakeTensorMode
        ):
            if self.record_faketensor:
                if func != torch.ops.prim.device.default:
                    self.operators.append((func, args, kwargs, self.call_depth + 1))
        elif len(types) == 0:
            if self.record_realtensor:
                self.operators.append((func, args, kwargs, self.call_depth + 1))

        result = func(*args, **kwargs)

        return result

    def __enter__(self):
        self.operators = []
        self.call_depth = 0

        if self.record_torchfunction:
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        return self

    # pyrefly: ignore  # bad-override
    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()

    @contextlib.contextmanager
    def record_redistribute_calls(self, arg_idx, src_placement, dst_placement):
        try:
            self.operators.append(
                (
                    REDISTRIBUTE_FUNC,
                    [arg_idx, src_placement, dst_placement],
                    {},
                    self.call_depth + 1,
                )
            )
            self.call_depth += 1
            yield
        finally:
            self.call_depth -= 1

    def debug_string(self) -> str:
        with torch._C.DisableTorchFunction():
            result = ""
            result += "\n".join(
                " "
                + " " * depth
                + _op_to_str(op, self.record_tensor_attributes, *args, **kwargs)
                for op, args, kwargs, depth in self.operators
            )
        return result


def get_active_debug_mode() -> Optional[DebugMode]:
    debug_mode = None
    for mode in _get_current_dispatch_mode_stack():
        if isinstance(mode, DebugMode):
            debug_mode = mode
            break
    return debug_mode
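As a usage note, get_active_debug_mode lets instrumented code detect an active DebugMode and log extra events under it. A minimal sketch follows; the helper name and its arguments are hypothetical, not part of this file.

import contextlib

from torch.utils._debug_mode import get_active_debug_mode


def _redistribute_with_logging(arg_idx, src_placement, dst_placement, run_redistribute):
    # If a DebugMode is on the dispatch mode stack, record this redistribute
    # call; it shows up as "redistribute_input(...)" in debug_string().
    debug_mode = get_active_debug_mode()
    ctx = (
        debug_mode.record_redistribute_calls(arg_idx, src_placement, dst_placement)
        if debug_mode is not None
        else contextlib.nullcontext()
    )
    with ctx:
        return run_redistribute()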