From 0099e15b47bccc498d085fc6e5da32c1b5239cec Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 11 Jul 2024 07:38:18 -0700 Subject: [PATCH] Also put unbacked symbols in symbol_to_node in split_module pass (#130535) This is not a complete fix but it is a simple one, full fix tracked in https://github.com/pytorch/pytorch/issues/130534 Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7510238679103969/ Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/130535 Approved by: https://github.com/malfet --- test/distributed/test_dynamo_distributed.py | 95 ++++++++++++++++++++- torch/fx/passes/split_module.py | 68 ++++++++++++--- 2 files changed, 151 insertions(+), 12 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 4525c5b1c35..3b6c8dab671 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -265,10 +265,9 @@ class CheckSplitsCompiler: # other important thing is patching _active_ddp_module, which is what actually # triggers DDP optimization class FakeDDP(nn.Module): - def __init__(self, module): + def __init__(self, module, bucket_cap_mb=25): super().__init__() self.module = module - bucket_cap_mb = 25 self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) @contextmanager @@ -351,6 +350,98 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase): opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512)) + @config.patch(optimize_ddp=True, capture_scalar_outputs=True) + def test_unbacked_symbol_splitting_direct(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.weight1 = nn.Parameter(torch.randn(512, 512)) + self.weight2 = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x, y): + u0, u1 = y.tolist() + x = torch.cat([x, x]) + y = x @ self.weight1 + z = (x + y @ self.weight2) * u0 + return z + + model = Model() + model = FakeDDP(model) + + opt_model = torch.compile(dynamic=True)(model) + opt_model(torch.randn(20, 512), torch.tensor([12, 13])) + + @config.patch(optimize_ddp=True, capture_scalar_outputs=True) + def test_unbacked_symbol_splitting_indirect(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.weight1 = nn.Parameter(torch.randn(512, 512)) + self.weight2 = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x, y): + u0, u1 = y.tolist() + a = torch.ones(u0) + x = torch.cat([x, x]) + y = x @ self.weight1 + z = (x + y @ self.weight2) * a.sum() + return z + + model = Model() + model = FakeDDP(model) + + opt_model = torch.compile(dynamic=True)(model) + opt_model(torch.randn(20, 512), torch.tensor([12, 13])) + + @config.patch(optimize_ddp=True, capture_scalar_outputs=True) + def test_unbacked_symbol_splitting_torture_multi(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.weight1 = nn.Parameter(torch.randn(512, 512)) + self.weight2 = nn.Parameter(torch.randn(512, 512)) + self.weight3 = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x, y): + # partition one (contains the u0 def) + u0, u1 = y.tolist() + x = torch.cat([x, x]) + y1 = x @ self.weight1 + # partition two (contains the variable) + y2 = y1 @ self.weight2 + a = torch.ones(u0) + # partition three + z = (x + y2 @ self.weight3) * a.sum() + return z + + model = Model() + model = FakeDDP(model, bucket_cap_mb=1) + + opt_model = torch.compile(dynamic=True)(model) + opt_model(torch.randn(20, 
512), torch.tensor([12, 13])) + + @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534" + @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) + def test_unbacked_symbol_splitting_no_binding(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.weight1 = nn.Parameter(torch.randn(512, 512)) + self.weight2 = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x, y): + nz = y.nonzero() + x = torch.cat([x, x]) + y = x @ self.weight1 + z = (x + y @ self.weight2) * (nz + 1).sum() + return z + + model = Model() + model = FakeDDP(model) + + opt_model = torch.compile(dynamic=True)(model) + opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0])) + @patch.object(config, "optimize_ddp", True) def test_call_method_forward(self): class Model(nn.Module): diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 093d7e4071d..7682129f7ab 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -8,10 +8,11 @@ import torch from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule from torch.fx.node import Node +from torch.fx._utils import lazy_format_graph_code __all__ = ["Partition", "split_module"] -_LOGGER = logging.getLogger(__name__) +log = _LOGGER = logging.getLogger(__name__) @compatibility(is_backward_compatible=True) class Partition: @@ -143,6 +144,13 @@ def split_module( True """ + log.debug( + "%s", + lazy_format_graph_code( + "pre split_module", m, colored=True + ), + ) + def construct_graph( node: Node, base_mod_env: Dict[str, Node], @@ -184,6 +192,12 @@ def split_module( defined = getattr(def_node, "_fx_partition", None) used = getattr(use_node, "_fx_partition", None) + + log.debug( + "record_cross_partition_use %s (%s) %s (%s)", + def_node.name, defined, use_node.name if use_node is not None else "-", used + ) + if defined != used: if defined is not None: def_partition = partitions[defined] @@ -194,14 +208,33 @@ def split_module( if used is not None: use_partition = partitions[used] use_partition.inputs.setdefault(def_node.name) + # We have made def_node an input to the use_partition. If + # this input has symbolic symbols in its size, those also must + # be made as inputs to the partition if (def_val := def_node.meta.get("example_value")) is not None: for s in sorted(free_symbols(def_val), key=str): - use_partition.inputs.setdefault(symbol_to_node[s].name) + s_node = symbol_to_node[s] + use_partition.inputs.setdefault(s_node.name) + if symbol_to_node[s].op != "placeholder": + # If the node that defines the symbol is not a + # placeholder, we must make it an output of the + # partition. Note that this may be in a different + # partition than defined! Although, this doesn't + # really make a difference for correctness, since + # defined is guaranteed to have the symbol in + # scope and can return it; you just get less + # optimal codegen in this case. 
+ s_defined = getattr(s_node, "_fx_partition", None) + if s_defined is not None: + s_def_partition = partitions[s_defined] + s_def_partition.outputs.setdefault(s_node.name) + s_def_partition.dependents.setdefault(used) if defined is not None: use_partition.dependencies.setdefault(defined) def instantiate_node_partition_mapping(node): partition_name = str(split_callback(node)) + log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name) # add node to partitions partition = partitions.get(partition_name) @@ -239,14 +272,22 @@ def split_module( active_autocasts = set() for node in m.graph.nodes: + # This will prefer placeholder bindings, because those come first. + # This is a little dangerous though: it is possible that an unbacked + # symbol is used without any binding site for it, in which case we + # will get a KeyError not able to find it. I'd like to fix this by + # having passes.runtime_assert establish some invariants that I can + # rely on later, but this needs some extra work. Quick fix first. + # See https://github.com/pytorch/pytorch/issues/130534 + if ( + (val := node.meta.get("example_value")) is not None and + isinstance(val, torch.SymInt) and + isinstance(s0 := val.node.expr, sympy.Symbol) and + s0 not in symbol_to_node + ): + symbol_to_node[val.node.expr] = node + if node.op in ["placeholder", "get_attr", "output"]: - if ( - node.op == "placeholder" and - (val := node.meta.get("example_value")) is not None and - isinstance(val, torch.SymInt) and - isinstance(val.node.expr, sympy.Symbol) - ): - symbol_to_node[val.node.expr] = node continue instantiate_node_partition_mapping(node) @@ -524,4 +565,11 @@ def split_module( torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) ) # noqa: B950 - return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + log.debug( + "%s", + lazy_format_graph_code( + "post split_module", ret, colored=True + ), + ) + return ret
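
For context, here is a minimal, self-contained usage sketch (not part of this patch) of how torch.fx.passes.split_module.split_module is driven by a split_callback; the TwoStage module, the midpoint heuristic, and all variable names below are illustrative assumptions, not code from this PR. In the Dynamo DDP-optimizer path exercised by the new tests, the callback is instead derived from bucket_bytes_cap, and the incoming graph carries "example_value" metadata containing SymInts, which is where the symbol_to_node bookkeeping fixed above matters.

    # Minimal sketch: partition a traced module into two submodules using a
    # topological-order split_callback. Assumes only the public split_module
    # API touched by this patch; module and callback are hypothetical.
    import torch
    import torch.fx
    from torch.fx.passes.split_module import split_module


    class TwoStage(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = torch.nn.Linear(8, 8)
            self.lin2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.lin2(torch.relu(self.lin1(x)))


    mod = TwoStage()
    traced = torch.fx.symbolic_trace(mod)

    # Assign each node a partition id (0 or 1) by its topological position;
    # the DDP optimizer does the analogous thing based on bucket boundaries.
    order = {node: i for i, node in enumerate(traced.graph.nodes)}
    midpoint = len(order) // 2


    def split_callback(node: torch.fx.Node) -> int:
        return 0 if order[node] < midpoint else 1


    split = split_module(traced, mod, split_callback)
    print(split.graph)  # the base graph calls submod_0 and submod_1 in sequence

With a topological split like this, the partitions form a chain and cross-partition values are threaded through as submodule inputs and outputs; the change above additionally registers the nodes that bind unbacked symbols (e.g. from .tolist()) in symbol_to_node, so the sizes of those cross-partition values can be threaded through in the same way.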