diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 360a13dd711..cb0776302dd 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -657,9 +657,18 @@ def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]): yield node - def prehook(grad_output): + def fmt(t): + # Avoid circular import + from torch.testing._internal.common_utils import dtype_abbrs + + if t is None: + return "None" + return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]" + + def prehook(grad_outputs): node = torch._C._current_autograd_node() - log_str = f"Executing: {node} with grad_output: {grad_output}" + grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]" + log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}" log.debug(log_str) handles = []