Fix DDPOptimizer issue on static tensor index (#155746)
We rely on `_try_get_metadata_from_dynamo()` to get the static input indices. When the meta info is missing, it silently returns an empty list of static input indices. This wrong list leads to repeated cudagraph re-recording, which looks like a hang from the user's perspective (bc3972b80a/torch/_functorch/aot_autograd.py L1025-L1031).

The root cause is that `split_module` in the DDP Optimizer loses the meta info and gm attributes. This PR fixes the issue by propagating this metadata from the original module to the submodules (bc3972b80a/torch/_dynamo/backends/distributed.py L515-L517).

Fixes #140395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155746
Approved by: https://github.com/xmfan, https://github.com/bdhirsh
This commit is contained in:
parent 3b6569b1ef
commit 38410cf9b5
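For illustration only (not part of this commit): a minimal sketch, using a hypothetical toy module and a stand-in metadata key, of the failure mode described above. `torch.fx.passes.split_module.split_module` produces submodules that do not carry the parent GraphModule's `meta`, so anything keyed on it (such as `dynamo_compile_id`) goes missing unless it is copied over, which is what the new `propagate_metadata` helper in this commit does.

# Sketch only; `Toy` and the meta-key usage are illustrative assumptions.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)  # goes to partition 0
        return y * 2       # goes to partition 1


gm = symbolic_trace(Toy())
gm.meta["dynamo_compile_id"] = "stand-in"  # mimics metadata dynamo would attach

split_gm = split_module(gm, gm, lambda node: 0 if node.name == "relu" else 1)

# The partition submodules come back with empty `meta`, which is why the
# commit propagates it explicitly.
print([(name, dict(mod.meta)) for name, mod in split_gm.named_modules()
       if "." not in name and len(name)])

# The fix, in spirit: copy the parent's metadata onto each top-level submodule.
for name, mod in split_gm.named_modules():
    if "." not in name and len(name):
        mod.meta = gm.meta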
@@ -668,6 +668,50 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
             outputs = fsdp_m(inputs)
             self.assertTrue(same(correct_outputs, outputs))
 
+    @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    def test_ddp_optimizer_cudagraph(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # need a large channel to trigger ddp optimizer split module
+                self.CHANNELS = 640
+                self.convi = nn.Conv2d(46, self.CHANNELS, 3, padding=1, bias=False)
+                self.convp = nn.Conv2d(
+                    self.CHANNELS, self.CHANNELS, 1, padding=0, bias=False
+                )
+                self.bni = nn.BatchNorm2d(self.CHANNELS)
+
+            def forward(self, bitmap_channels):
+                x = self.convi(bitmap_channels)
+                x = self.bni(x)
+                x = self.convp(x)
+                return x
+
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            net = Net().to(self.rank)
+            optimizer = torch.optim.SGD(
+                net.parameters(),
+                lr=5e-2,
+            )
+
+            net = DDP(net, device_ids=[self.rank])
+            opt_net = torch.compile(net, mode="reduce-overhead")
+            opt_net.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                data = torch.randn((16, 46, 8, 8), dtype=torch.float32, device="cuda")
+                opt_net(data).sum().backward()
+
+            # 2 fwd and 2 bwd graph such that 4 graphs in total
+            graph_id = (
+                torch._inductor.cudagraph_trees.get_container(self.rank)
+                .tree_manager.new_graph_id()
+                .id
+            )
+            self.assertTrue(graph_id == 4)
+
     @config.patch(enable_compiler_collectives=True)
     @skip_if_lt_x_gpu(1)
     def test_fsdp_setattr(self):
@@ -146,6 +146,26 @@ def has_higher_order_op(gm):
     return False
 
 
+def propagate_metadata(orig_gm, split_gm) -> None:
+    for name, module in split_gm.named_modules():
+        if "." not in name and len(name):
+            # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
+            module.meta = orig_gm.meta
+            module._param_name_to_source = orig_gm._param_name_to_source
+
+
+def propagate_dynamo_source(orig_gm, split_gm) -> None:
+    name_to_dynamo_source = {}
+    for node in orig_gm.graph.find_nodes(op="placeholder"):
+        name_to_dynamo_source[node.name] = node._dynamo_source
+
+    for name, module in split_gm.named_modules():
+        if "." not in name and len(name):
+            for node in module.graph.find_nodes(op="placeholder"):
+                # non-placeholder in original_gm may become placeholder in submodules
+                node._dynamo_source = name_to_dynamo_source.get(node.name, None)
+
+
 # compile each of the partitioned submodules using the user-provided compiler
 class SubmodCompiler(torch.fx.interpreter.Interpreter):
     def __init__(self, module, compiler, fake_mode) -> None:
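The `name_to_dynamo_source.get(node.name, None)` lookup above is needed because, as the comment notes, a value that was an intermediate in the original graph can become a placeholder of a later submodule and therefore has no user-level source. A minimal sketch of that situation (toy module, not from this commit):

# Sketch only; `Tiny` and the partitioning are illustrative assumptions.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Tiny(torch.nn.Module):
    def forward(self, x):
        y = x + 1      # node "add" -> partition 0
        return y * 2   # node "mul" -> partition 1


gm = symbolic_trace(Tiny())
split_gm = split_module(gm, gm, lambda node: 0 if node.name == "add" else 1)

# The second submodule's placeholder is the intermediate "add", which never
# existed as a user input, so the map built from the original placeholders
# has no entry for it and the propagated `_dynamo_source` is None.
print([n.name for n in split_gm.submod_1.graph.find_nodes(op="placeholder")])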
@@ -516,6 +536,10 @@ class DDPOptimizer:
             gm, None, lambda node: partition_map[node]
         )
 
+        # See note [Assumption on Dynamo Metadata]
+        propagate_dynamo_source(gm, split_gm)
+        propagate_metadata(gm, split_gm)
+
         debug_str = (
             f"\n---orig graph---\n{gm.graph}\n"
             + f"\n---split graph---\n{split_gm.graph}\n"
@@ -1022,6 +1022,12 @@ def _try_get_metadata_from_dynamo(
         aot_autograd_arg_pos_to_source: used to dedup params and their guards
         static_input_indices: used to identify static inputs for cudagraphs
     """
+    # Note [Assumption on Dynamo Metadata]
+    # This function assumes a graph module from dynamo provides `dynamo_compiled_id`,
+    # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes.
+    # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to
+    # be propagated in order to be recognized as a dynamo graph
+
     if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta):
         # graph was not captured by dynamo
         return None, []
@@ -1055,7 +1061,10 @@ def _try_get_metadata_from_dynamo(
     for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
         assert hasattr(node, "_dynamo_source")
         source = node._dynamo_source
-        assert source not in seen_sources, source
+        # `source` specifies the source from user code. ddp optimizer may have
+        # intermediate values becoming submodule placeholders which does not
+        # have a source
+        assert source is None or source not in seen_sources, source
         seen_sources.add(source)
         aot_autograd_arg_pos_to_source.append(source)
         source_name = source.name() if source else str(source)