pytorch/torch/_inductor/comms_debug.py
PyTorch MergeBot 1dd6b76914 Revert "[1/N] Remove unused loop variables (#166258)"
This reverts commit 76b2c37045.

Reverted https://github.com/pytorch/pytorch/pull/166258 on behalf of https://github.com/atalman due to breaks test/distributed/test_serialization.py::TestSerialization::test_weights_only [GH job link](https://github.com/pytorch/pytorch/actions/runs/18894311802/job/53929321703) [HUD commit link](76b2c37045) ([comment](https://github.com/pytorch/pytorch/pull/166258#issuecomment-3460964612))
2025-10-29 11:10:37 +00:00

113 lines
4.1 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Union
from torch._logging import trace_structured
from .memory import estimate_peak_memory_allocfree
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from .memory import FreeableInputBuffer, SNodeMemory
from .scheduler import BaseSchedulerNode, SchedulerBuffer
def _debug_iterative_memory_recompute(
    candidate: BaseSchedulerNode,
    gns: list[BaseSchedulerNode],
    group_names: str,
    snodes: list[BaseSchedulerNode],
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
    graph_outputs: OrderedSet[str],
    peak_memory: int,
    iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
    snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
    tlparse_name: str,
    gn_to_bufs_last_use: dict[
        BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
    ],
) -> bool:
    """Compare iteratively tracked memory numbers with a fresh estimate.

    Calls ``estimate_peak_memory_allocfree`` on ``snodes`` and checks its
    result against the values supplied by the caller. Three checks are made:

    * the freshly estimated peak is strictly greater than ``peak_memory``;
    * the current-memory entry of ``candidate`` differs between
      ``iter_curr_memory`` and the fresh estimate;
    * the current-memory entry of any node in ``gns`` differs likewise.

    When at least one check fails, a textual report is printed to stdout and
    emitted through ``trace_structured`` as an ``"artifact"`` named
    ``f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"``.

    Args:
        candidate: Node whose entries are compared and reported first.
        gns: Group nodes whose current-memory entries are also compared.
        group_names: Text included verbatim in the report (``GROUP:`` line).
        snodes: Node order passed to ``estimate_peak_memory_allocfree``.
        name_to_freeable_input_buf: Forwarded to the estimator.
        graph_outputs: Forwarded to the estimator.
        peak_memory: Reference peak the fresh estimate is compared against.
        iter_curr_memory: Caller-maintained per-node current-memory tuples.
        snodes_allocfree: Caller-maintained per-node alloc/free records,
            reported under the ``ITER_allocfree`` labels.
        tlparse_name: Prefix of the emitted artifact name.
        gn_to_bufs_last_use: Included verbatim in the report.

    Returns:
        ``True`` if any of the checks above failed, ``False`` otherwise.

    Raises:
        KeyError: If ``candidate`` is missing from ``snodes_allocfree``, or
            if ``candidate`` or a node of ``gns`` is missing from
            ``iter_curr_memory`` or from ``snodes``.
    """
    # Keep the caller's mapping under its own name. The estimator returns a
    # fresh mapping, and reusing the parameter name for it would make the
    # ITER_* and ESTM_* lines of the report print the same (estimated) data.
    iter_allocfree = snodes_allocfree
    candidate_allocfree = iter_allocfree[candidate]

    est_peak_memory, snodes_curr_memory, est_allocfree, _ = (
        estimate_peak_memory_allocfree(
            snodes, name_to_freeable_input_buf, graph_outputs
        )
    )
    est_curr_memory = dict(zip(snodes, snodes_curr_memory))
    iter_cm = iter_curr_memory[candidate]
    new_cm = est_curr_memory[candidate]

    # Collect every failed check so that no reason is lost when several fail.
    mismatches: list[str] = []
    if est_peak_memory > peak_memory:
        mismatches.append("ITERATIVE PEAK DOES NOT MATCH")
    if iter_cm != new_cm:
        mismatches.append("ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH")
    for gn in gns:
        if iter_curr_memory[gn] != est_curr_memory[gn]:
            mismatches.append(
                f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
            )

    if not mismatches:
        return False

    parts: list[str] = [
        "\n".join(mismatches),
        f"\nCANDIDATE:{candidate.get_name()}"
        f"\nGROUP:{group_names}"
        f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
        f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
        f"\nCANDIDATE:{candidate.debug_str()}"
        f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
        f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
        f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
        f"\nCANDIDATE_NEW_ALLOCFREE:{est_allocfree[candidate]}",
    ]

    # Report the first node whose first current-memory value equals the
    # freshly estimated peak.
    for i, (pre, _post) in enumerate(snodes_curr_memory):
        if est_peak_memory == pre:
            n = snodes[i]
            parts.append(
                f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
                f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
            )
            break

    # ``.get`` is used for the caller-side mapping in the dumps below so that
    # this reporting path cannot raise on a node the caller did not record.
    for i, gn in enumerate(gns):
        parts.append(
            f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
            f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_curr_memory[gn]}"
            f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{est_curr_memory[gn]}"
            f"\nGROUP_NODE[{i}] ITER_allocfree:{iter_allocfree.get(gn)}"
            f"\nGROUP_NODE[{i}] ESTM_allocfree:{est_allocfree[gn]}"
        )

    parts.append(f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}")
    parts.append(
        "\n\n".join(
            f"\nSNODE[{i}]\n{n.debug_str()}"
            f"\nITER_cur_mem:{iter_curr_memory[n]}"
            f"\nESTM_cur_mem:{est_curr_memory[n]}"
            f"\nITER_allocfree:{iter_allocfree.get(n)}"
            f"\nESTM_allocfree:{est_allocfree[n]}"
            for i, n in enumerate(snodes)
        )
    )

    log = "".join(parts)
    tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
    print(f"{tname}:\n{log}")
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": tname,
            "encoding": "string",
        },
        payload_fn=lambda: log,
    )
    return True