Fix DDPOptimizer issue on static tensor index (#155746)

We rely on `_try_get_metadata_from_dynamo()` to get the static input indices. When the meta info is missing, it simply returns an empty list of static input indices. This wrong list of static input indices leads to repeated cudagraph re-recording, which looks like a hang from the user's perspective. bc3972b80a/torch/_functorch/aot_autograd.py (L1025-L1031)
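(For context, not part of this change: the metadata that `_try_get_metadata_from_dynamo()` looks for can be inspected from a custom backend. The probe below is a minimal sketch and assumes a recent build where dynamo populates `dynamo_compile_id` in `gm.meta`, `_param_name_to_source` on the gm, and `_dynamo_source` on every placeholder, as described in the note added by this PR.)

import torch
import torch.nn as nn

def probe_backend(gm: torch.fx.GraphModule, example_inputs):
    # Illustrative only: print the dynamo metadata that
    # _try_get_metadata_from_dynamo() later relies on.
    print("dynamo_compile_id" in gm.meta)        # expected True for a dynamo gm
    print(hasattr(gm, "_param_name_to_source"))  # expected True
    for node in gm.graph.find_nodes(op="placeholder"):
        # placeholders for params/buffers carry a source; these are the
        # candidates for static input indices under cudagraphs
        print(node.name, getattr(node, "_dynamo_source", None))
    return gm.forward  # run the captured graph as-is

net = nn.Linear(4, 4)
compiled = torch.compile(net, backend=probe_backend)
compiled(torch.randn(2, 4))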

The root cause is that `split_module` in DDPOptimizer loses the meta info and gm attributes. This PR fixes the issue by propagating this metadata from the original module to the submodules.
bc3972b80a/torch/_dynamo/backends/distributed.py (L515-L517)
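In essence, the fix copies the dynamo metadata onto every top-level submodule that `split_module` produces. A standalone sketch of the idea, assuming `gm` is the dynamo-captured graph and `partition_fn` is the DDPOptimizer bucketing callback (the authoritative change is in the diff below):

import torch
from torch.fx.passes.split_module import split_module

def split_with_dynamo_metadata(gm: torch.fx.GraphModule, partition_fn):
    # Sketch: split the graph the way DDPOptimizer does, then copy the dynamo
    # metadata onto each top-level submodule so AOTAutograd still recognizes it.
    split_gm = split_module(gm, None, partition_fn)
    name_to_source = {
        n.name: n._dynamo_source for n in gm.graph.find_nodes(op="placeholder")
    }
    for name, submod in split_gm.named_modules():
        if name and "." not in name:  # top-level submodules only
            submod.meta = gm.meta
            submod._param_name_to_source = gm._param_name_to_source
            for node in submod.graph.find_nodes(op="placeholder"):
                # intermediates of the original graph become placeholders of
                # later submodules; those have no user-facing source
                node._dynamo_source = name_to_source.get(node.name, None)
    return split_gm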

Fixes #140395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155746
Approved by: https://github.com/xmfan, https://github.com/bdhirsh
Author: Boyuan Feng, 2025-06-14 00:15:58 +00:00 (committed by PyTorch MergeBot)
Parent: 3b6569b1ef
Commit: 38410cf9b5
3 changed files with 78 additions and 1 deletion

test/distributed/test_dynamo_distributed.py

@@ -668,6 +668,50 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_ddp_optimizer_cudagraph(self):
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                # need a large channel to trigger ddp optimizer split module
                self.CHANNELS = 640
                self.convi = nn.Conv2d(46, self.CHANNELS, 3, padding=1, bias=False)
                self.convp = nn.Conv2d(
                    self.CHANNELS, self.CHANNELS, 1, padding=0, bias=False
                )
                self.bni = nn.BatchNorm2d(self.CHANNELS)

            def forward(self, bitmap_channels):
                x = self.convi(bitmap_channels)
                x = self.bni(x)
                x = self.convp(x)
                return x

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            net = Net().to(self.rank)
            optimizer = torch.optim.SGD(
                net.parameters(),
                lr=5e-2,
            )
            net = DDP(net, device_ids=[self.rank])
            opt_net = torch.compile(net, mode="reduce-overhead")
            opt_net.train()

            for _ in range(10):
                optimizer.zero_grad()
                data = torch.randn((16, 46, 8, 8), dtype=torch.float32, device="cuda")
                opt_net(data).sum().backward()

            # 2 fwd and 2 bwd graphs, i.e., 4 graphs in total
            graph_id = (
                torch._inductor.cudagraph_trees.get_container(self.rank)
                .tree_manager.new_graph_id()
                .id
            )
            self.assertTrue(graph_id == 4)

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_setattr(self):

torch/_dynamo/backends/distributed.py

@@ -146,6 +146,26 @@ def has_higher_order_op(gm):
    return False


def propagate_metadata(orig_gm, split_gm) -> None:
    for name, module in split_gm.named_modules():
        if "." not in name and len(name):
            # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
            module.meta = orig_gm.meta
            module._param_name_to_source = orig_gm._param_name_to_source


def propagate_dynamo_source(orig_gm, split_gm) -> None:
    name_to_dynamo_source = {}
    for node in orig_gm.graph.find_nodes(op="placeholder"):
        name_to_dynamo_source[node.name] = node._dynamo_source

    for name, module in split_gm.named_modules():
        if "." not in name and len(name):
            for node in module.graph.find_nodes(op="placeholder"):
                # a non-placeholder node in orig_gm may become a placeholder in a submodule
                node._dynamo_source = name_to_dynamo_source.get(node.name, None)


# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler, fake_mode) -> None:
@@ -516,6 +536,10 @@ class DDPOptimizer:
            gm, None, lambda node: partition_map[node]
        )

        # See note [Assumption on Dynamo Metadata]
        propagate_dynamo_source(gm, split_gm)
        propagate_metadata(gm, split_gm)

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"

torch/_functorch/aot_autograd.py

@@ -1022,6 +1022,12 @@ def _try_get_metadata_from_dynamo(
        aot_autograd_arg_pos_to_source: used to dedup params and their guards
        static_input_indices: used to identify static inputs for cudagraphs
    """
    # Note [Assumption on Dynamo Metadata]
    # This function assumes a graph module from dynamo provides `dynamo_compile_id`
    # in its meta, a `_param_name_to_source` attribute, and a `_dynamo_source`
    # attribute on every placeholder node. When the gm is modified (e.g., by
    # DDPOptimizer via split_module), this metadata needs to be propagated for the
    # gm to be recognized as a dynamo graph.
    if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta):
        # graph was not captured by dynamo
        return None, []
@@ -1055,7 +1061,10 @@ def _try_get_metadata_from_dynamo(
    for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
        assert hasattr(node, "_dynamo_source")
        source = node._dynamo_source
        # `source` specifies the source from user code. The DDP optimizer may turn
        # intermediate values into submodule placeholders, which do not have a source.
        assert source is None or source not in seen_sources, source
        seen_sources.add(source)
        aot_autograd_arg_pos_to_source.append(source)
        source_name = source.name() if source else str(source)