Also put unbacked symbols in symbol_to_node in split_module pass (#130535)
This is not a complete fix, but it is a simple one; the full fix is tracked in https://github.com/pytorch/pytorch/issues/130534. Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7510238679103969/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130535
Approved by: https://github.com/malfet
Parent: ca2d424c6e
Commit: 0099e15b47
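For context (not part of this commit): split_module carves an fx.GraphModule into submodules according to a caller-supplied callback that maps each node to a partition; Dynamo's DDP optimizer uses it to split a graph at bucket boundaries, and symbol_to_node is the pass's map from sympy symbols to the nodes that bind them. A minimal, hedged sketch of the split_module API itself (the toy module and the alternating callback below are made up for illustration):

```python
# Minimal sketch of the split_module API; the module and callback are
# illustrative, not taken from this commit.
import torch
import torch.fx
from torch.fx.passes.split_module import split_module


class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return b - 3


toy = Toy()
gm = torch.fx.symbolic_trace(toy)

state = {"counter": 0}


def split_callback(node) -> int:
    # Put the first two call nodes in partition 0, the rest in partition 1.
    state["counter"] += 1
    return 0 if state["counter"] <= 2 else 1


split = split_module(gm, toy, split_callback)
print(split.graph)  # the root graph now calls submod_0 and submod_1
```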
@@ -265,10 +265,9 @@ class CheckSplitsCompiler:
 # other important thing is patching _active_ddp_module, which is what actually
 # triggers DDP optimization
 class FakeDDP(nn.Module):
-    def __init__(self, module):
+    def __init__(self, module, bucket_cap_mb=25):
         super().__init__()
         self.module = module
-        bucket_cap_mb = 25
         self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
 
     @contextmanager
@@ -351,6 +350,98 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512))
 
+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_direct(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                u0, u1 = y.tolist()
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * u0
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_indirect(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                u0, u1 = y.tolist()
+                a = torch.ones(u0)
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * a.sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_torture_multi(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+                self.weight3 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                # partition one (contains the u0 def)
+                u0, u1 = y.tolist()
+                x = torch.cat([x, x])
+                y1 = x @ self.weight1
+                # partition two (contains the variable)
+                y2 = y1 @ self.weight2
+                a = torch.ones(u0)
+                # partition three
+                z = (x + y2 @ self.weight3) * a.sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model, bucket_cap_mb=1)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/130534
+    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
+    def test_unbacked_symbol_splitting_no_binding(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                nz = y.nonzero()
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * (nz + 1).sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))
+
     @patch.object(config, "optimize_ddp", True)
     def test_call_method_forward(self):
         class Model(nn.Module):
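The new tests above all hinge on unbacked SymInts: with capture_scalar_outputs=True, data-dependent integers coming from .tolist() (or .item()) become symbols u0, u1, ... whose values are unknown at compile time, and the DDP graph splitter has to carry those symbols across bucket partitions. A hedged, standalone illustration of where such symbols come from, independent of DDP and of this commit (the remaining hunks below are in torch/fx/passes/split_module.py):

```python
# Hedged illustration: .tolist() under capture_scalar_outputs=True yields
# unbacked SymInts, and a tensor built from one (torch.ones(u0)) has a
# data-dependent size. This is the ingredient the new tests exercise.
import torch
import torch._dynamo


@torch.compile(dynamic=True)
def f(x, y):
    u0, u1 = y.tolist()   # u0, u1 become unbacked SymInts during tracing
    a = torch.ones(u0)    # tensor whose size is the unbacked u0
    return a.sum() + x * u1


with torch._dynamo.config.patch(capture_scalar_outputs=True):
    print(f(torch.randn(3), torch.tensor([4, 5])))
```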
@@ -8,10 +8,11 @@ import torch
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
+from torch.fx._utils import lazy_format_graph_code
 
 
 __all__ = ["Partition", "split_module"]
-_LOGGER = logging.getLogger(__name__)
+log = _LOGGER = logging.getLogger(__name__)
 
 @compatibility(is_backward_compatible=True)
 class Partition:
@@ -143,6 +144,13 @@ def split_module(
         True
     """
 
+    log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "pre split_module", m, colored=True
+        ),
+    )
+
     def construct_graph(
         node: Node,
         base_mod_env: Dict[str, Node],
@@ -184,6 +192,12 @@ def split_module(
 
         defined = getattr(def_node, "_fx_partition", None)
         used = getattr(use_node, "_fx_partition", None)
+
+        log.debug(
+            "record_cross_partition_use %s (%s) %s (%s)",
+            def_node.name, defined, use_node.name if use_node is not None else "-", used
+        )
+
         if defined != used:
             if defined is not None:
                 def_partition = partitions[defined]
@@ -194,14 +208,33 @@ def split_module(
             if used is not None:
                 use_partition = partitions[used]
                 use_partition.inputs.setdefault(def_node.name)
+                # We have made def_node an input to the use_partition. If
+                # this input has symbolic symbols in its size, those also must
+                # be made as inputs to the partition
                 if (def_val := def_node.meta.get("example_value")) is not None:
                     for s in sorted(free_symbols(def_val), key=str):
-                        use_partition.inputs.setdefault(symbol_to_node[s].name)
+                        s_node = symbol_to_node[s]
+                        use_partition.inputs.setdefault(s_node.name)
+                        if symbol_to_node[s].op != "placeholder":
+                            # If the node that defines the symbol is not a
+                            # placeholder, we must make it an output of the
+                            # partition. Note that this may be in a different
+                            # partition than defined! Although, this doesn't
+                            # really make a difference for correctness, since
+                            # defined is guaranteed to have the symbol in
+                            # scope and can return it; you just get less
+                            # optimal codegen in this case.
+                            s_defined = getattr(s_node, "_fx_partition", None)
+                            if s_defined is not None:
+                                s_def_partition = partitions[s_defined]
+                                s_def_partition.outputs.setdefault(s_node.name)
+                                s_def_partition.dependents.setdefault(used)
                 if defined is not None:
                     use_partition.dependencies.setdefault(defined)
 
     def instantiate_node_partition_mapping(node):
         partition_name = str(split_callback(node))
+        log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)
 
         # add node to partitions
         partition = partitions.get(partition_name)
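An aside on the setdefault(name) calls in this hunk (a pre-existing idiom, not introduced by this commit): Partition.inputs, outputs, dependencies, and dependents are plain dicts used as ordered sets, so setdefault(key) records the key with a None value while preserving insertion order and ignoring duplicates. A tiny sketch:

```python
# Dict-as-ordered-set idiom used by the Partition bookkeeping (sketch).
inputs: dict = {}
inputs.setdefault("x")    # records "x" -> None
inputs.setdefault("s0")   # records the symbol node's name
inputs.setdefault("x")    # duplicate: no effect, insertion order preserved
print(list(inputs))       # ['x', 's0']
```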
@@ -239,14 +272,22 @@ def split_module(
     active_autocasts = set()
 
     for node in m.graph.nodes:
+        # This will prefer placeholder bindings, because those come first.
+        # This is a little dangerous though: it is possible that an unbacked
+        # symbol is used without any binding site for it, in which case we
+        # will get a KeyError not able to find it. I'd like to fix this by
+        # having passes.runtime_assert establish some invariants that I can
+        # rely on later, but this needs some extra work. Quick fix first.
+        # See https://github.com/pytorch/pytorch/issues/130534
+        if (
+            (val := node.meta.get("example_value")) is not None and
+            isinstance(val, torch.SymInt) and
+            isinstance(s0 := val.node.expr, sympy.Symbol) and
+            s0 not in symbol_to_node
+        ):
+            symbol_to_node[val.node.expr] = node
+
         if node.op in ["placeholder", "get_attr", "output"]:
-            if (
-                node.op == "placeholder" and
-                (val := node.meta.get("example_value")) is not None and
-                isinstance(val, torch.SymInt) and
-                isinstance(val.node.expr, sympy.Symbol)
-            ):
-                symbol_to_node[val.node.expr] = node
             continue
 
         instantiate_node_partition_mapping(node)
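The mapping populated in this hunk is keyed by sympy Symbols: free_symbols() on a node's example_value (a SymInt or a fake tensor) reports the symbols it mentions, and each such symbol must resolve through symbol_to_node to the node that binds it, which is exactly what breaks when an unbacked symbol was never registered. A hedged sketch of those ingredients, built directly on a ShapeEnv rather than on a traced graph:

```python
# Hedged sketch (not from the commit): create an unbacked SymInt by hand and
# show that free_symbols() recovers the sympy Symbol that split_module uses
# as the key into symbol_to_node.
import sympy
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, free_symbols

shape_env = ShapeEnv()
u = shape_env.create_unbacked_symint()  # SymInt backed by a fresh symbol, e.g. u0

print(u, type(u))                       # e.g. u0 <class 'torch.SymInt'>
print(free_symbols(u))                  # {u0}
print(free_symbols(u * 2 + 1))          # {u0}: derived expressions keep the symbol
assert all(isinstance(s, sympy.Symbol) for s in free_symbols(u))
```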
@@ -524,4 +565,11 @@ def split_module(
                 torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
             )  # noqa: B950
 
-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "post split_module", ret, colored=True
+        ),
+    )
+    return ret
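The log.debug(...) calls added in this file only emit when the torch.fx.passes.split_module logger is at DEBUG level. A hedged sketch of turning the new pre/post graph dumps on with the standard logging module (TORCH_LOGS may also be able to target this module logger, but that is not verified here):

```python
# Hedged sketch: enable the new "pre split_module" / "post split_module"
# graph dumps via the standard logging module. The logger name follows from
# __name__ in torch/fx/passes/split_module.py.
import logging

logging.basicConfig(level=logging.INFO)
logging.getLogger("torch.fx.passes.split_module").setLevel(logging.DEBUG)

# ...then run a DDP-optimized torch.compile workload; the pass logs the
# graph before and after splitting.
```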