Also put unbacked symbols in symbol_to_node in split_module pass (#130535)

This is not a complete fix but it is a simple one, full fix tracked
in https://github.com/pytorch/pytorch/issues/130534

Internal xref:
https://fb.workplace.com/groups/6829516587176185/posts/7510238679103969/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130535
Approved by: https://github.com/malfet
This commit is contained in:
Edward Z. Yang 2024-07-11 07:38:18 -07:00 committed by PyTorch MergeBot
parent ca2d424c6e
commit 0099e15b47
2 changed files with 151 additions and 12 deletions

View File

@ -265,10 +265,9 @@ class CheckSplitsCompiler:
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
def __init__(self, module):
def __init__(self, module, bucket_cap_mb=25):
super().__init__()
self.module = module
bucket_cap_mb = 25
self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
@contextmanager
@ -351,6 +350,98 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
opt_model = torch.compile(dynamic=True)(model)
opt_model(torch.randn(20, 512))
@config.patch(optimize_ddp=True, capture_scalar_outputs=True)
def test_unbacked_symbol_splitting_direct(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight1 = nn.Parameter(torch.randn(512, 512))
self.weight2 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y):
u0, u1 = y.tolist()
x = torch.cat([x, x])
y = x @ self.weight1
z = (x + y @ self.weight2) * u0
return z
model = Model()
model = FakeDDP(model)
opt_model = torch.compile(dynamic=True)(model)
opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
@config.patch(optimize_ddp=True, capture_scalar_outputs=True)
def test_unbacked_symbol_splitting_indirect(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight1 = nn.Parameter(torch.randn(512, 512))
self.weight2 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y):
u0, u1 = y.tolist()
a = torch.ones(u0)
x = torch.cat([x, x])
y = x @ self.weight1
z = (x + y @ self.weight2) * a.sum()
return z
model = Model()
model = FakeDDP(model)
opt_model = torch.compile(dynamic=True)(model)
opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
@config.patch(optimize_ddp=True, capture_scalar_outputs=True)
def test_unbacked_symbol_splitting_torture_multi(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight1 = nn.Parameter(torch.randn(512, 512))
self.weight2 = nn.Parameter(torch.randn(512, 512))
self.weight3 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y):
# partition one (contains the u0 def)
u0, u1 = y.tolist()
x = torch.cat([x, x])
y1 = x @ self.weight1
# partition two (contains the variable)
y2 = y1 @ self.weight2
a = torch.ones(u0)
# partition three
z = (x + y2 @ self.weight3) * a.sum()
return z
model = Model()
model = FakeDDP(model, bucket_cap_mb=1)
opt_model = torch.compile(dynamic=True)(model)
opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
@unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534"
@config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
def test_unbacked_symbol_splitting_no_binding(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight1 = nn.Parameter(torch.randn(512, 512))
self.weight2 = nn.Parameter(torch.randn(512, 512))
def forward(self, x, y):
nz = y.nonzero()
x = torch.cat([x, x])
y = x @ self.weight1
z = (x + y @ self.weight2) * (nz + 1).sum()
return z
model = Model()
model = FakeDDP(model)
opt_model = torch.compile(dynamic=True)(model)
opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))
@patch.object(config, "optimize_ddp", True)
def test_call_method_forward(self):
class Model(nn.Module):

View File

@ -8,10 +8,11 @@ import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code
__all__ = ["Partition", "split_module"]
_LOGGER = logging.getLogger(__name__)
log = _LOGGER = logging.getLogger(__name__)
@compatibility(is_backward_compatible=True)
class Partition:
@ -143,6 +144,13 @@ def split_module(
True
"""
log.debug(
"%s",
lazy_format_graph_code(
"pre split_module", m, colored=True
),
)
def construct_graph(
node: Node,
base_mod_env: Dict[str, Node],
@ -184,6 +192,12 @@ def split_module(
defined = getattr(def_node, "_fx_partition", None)
used = getattr(use_node, "_fx_partition", None)
log.debug(
"record_cross_partition_use %s (%s) %s (%s)",
def_node.name, defined, use_node.name if use_node is not None else "-", used
)
if defined != used:
if defined is not None:
def_partition = partitions[defined]
@ -194,14 +208,33 @@ def split_module(
if used is not None:
use_partition = partitions[used]
use_partition.inputs.setdefault(def_node.name)
# We have made def_node an input to the use_partition. If
# this input has symbolic symbols in its size, those also must
# be made as inputs to the partition
if (def_val := def_node.meta.get("example_value")) is not None:
for s in sorted(free_symbols(def_val), key=str):
use_partition.inputs.setdefault(symbol_to_node[s].name)
s_node = symbol_to_node[s]
use_partition.inputs.setdefault(s_node.name)
if symbol_to_node[s].op != "placeholder":
# If the node that defines the symbol is not a
# placeholder, we must make it an output of the
# partition. Note that this may be in a different
# partition than defined! Although, this doesn't
# really make a difference for correctness, since
# defined is guaranteed to have the symbol in
# scope and can return it; you just get less
# optimal codegen in this case.
s_defined = getattr(s_node, "_fx_partition", None)
if s_defined is not None:
s_def_partition = partitions[s_defined]
s_def_partition.outputs.setdefault(s_node.name)
s_def_partition.dependents.setdefault(used)
if defined is not None:
use_partition.dependencies.setdefault(defined)
def instantiate_node_partition_mapping(node):
partition_name = str(split_callback(node))
log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)
# add node to partitions
partition = partitions.get(partition_name)
@ -239,14 +272,22 @@ def split_module(
active_autocasts = set()
for node in m.graph.nodes:
if node.op in ["placeholder", "get_attr", "output"]:
# This will prefer placeholder bindings, because those come first.
# This is a little dangerous though: it is possible that an unbacked
# symbol is used without any binding site for it, in which case we
# will get a KeyError not able to find it. I'd like to fix this by
# having passes.runtime_assert establish some invariants that I can
# rely on later, but this needs some extra work. Quick fix first.
# See https://github.com/pytorch/pytorch/issues/130534
if (
node.op == "placeholder" and
(val := node.meta.get("example_value")) is not None and
isinstance(val, torch.SymInt) and
isinstance(val.node.expr, sympy.Symbol)
isinstance(s0 := val.node.expr, sympy.Symbol) and
s0 not in symbol_to_node
):
symbol_to_node[val.node.expr] = node
if node.op in ["placeholder", "get_attr", "output"]:
continue
instantiate_node_partition_mapping(node)
@ -524,4 +565,11 @@ def split_module(
torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
log.debug(
"%s",
lazy_format_graph_code(
"post split_module", ret, colored=True
),
)
return ret