Also put unbacked symbols in symbol_to_node in split_module pass (#130535)
This is not a complete fix, but it is a simple one; the full fix is tracked in https://github.com/pytorch/pytorch/issues/130534

Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7510238679103969/

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130535
Approved by: https://github.com/malfet
This commit is contained in:
parent ca2d424c6e
commit 0099e15b47
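Background for the diff below: unbacked symbols are the size/value symbols torch.compile creates for data-dependent integers (e.g. from .item() or .tolist()) that have no placeholder binding in the graph. A minimal sketch, not taken from this PR but mirroring the scenario in the new tests, assuming torch._dynamo.config.capture_scalar_outputs is what turns such scalars into unbacked SymInts instead of graph breaks:

import torch
import torch._dynamo

# Without this flag, .tolist() on a traced tensor would simply graph-break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(dynamic=True)
def f(x, y):
    u0, u1 = y.tolist()   # u0, u1 become unbacked SymInts at trace time
    return (x * u0) + u1  # the unbacked symbols flow into the captured graph

print(f(torch.randn(4, 4), torch.tensor([12, 13])))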
@@ -265,10 +265,9 @@ class CheckSplitsCompiler:
 # other important thing is patching _active_ddp_module, which is what actually
 # triggers DDP optimization
 class FakeDDP(nn.Module):
-    def __init__(self, module):
+    def __init__(self, module, bucket_cap_mb=25):
         super().__init__()
         self.module = module
-        bucket_cap_mb = 25
         self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

     @contextmanager
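The hunk above turns the hard-coded bucket_cap_mb into a constructor argument so individual tests can shrink the DDP bucket size. Back-of-the-envelope arithmetic (a sketch, not part of the PR) for why the torture test below passes bucket_cap_mb=1: each 512x512 fp32 weight is exactly 1 MiB, so a 1 MiB cap ends up with roughly one bucket per weight, while the default 25 MiB cap would keep all three weights in one bucket and never split the graph.

weight_bytes = 512 * 512 * 4                      # one fp32 nn.Parameter(512, 512)
print(weight_bytes)                               # 1048576 bytes == 1 MiB
print(weight_bytes >= int(1 * 1024 * 1024))       # True: a 1 MiB cap is saturated by a single weight
print(3 * weight_bytes <= int(25 * 1024 * 1024))  # True: the default cap would hold all three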
@@ -351,6 +350,98 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512))

+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_direct(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                u0, u1 = y.tolist()
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * u0
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_indirect(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                u0, u1 = y.tolist()
+                a = torch.ones(u0)
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * a.sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
+    def test_unbacked_symbol_splitting_torture_multi(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+                self.weight3 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                # partition one (contains the u0 def)
+                u0, u1 = y.tolist()
+                x = torch.cat([x, x])
+                y1 = x @ self.weight1
+                # partition two (contains the variable)
+                y2 = y1 @ self.weight2
+                a = torch.ones(u0)
+                # partition three
+                z = (x + y2 @ self.weight3) * a.sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model, bucket_cap_mb=1)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
+
+    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/130534
+    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
+    def test_unbacked_symbol_splitting_no_binding(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight1 = nn.Parameter(torch.randn(512, 512))
+                self.weight2 = nn.Parameter(torch.randn(512, 512))
+
+            def forward(self, x, y):
+                nz = y.nonzero()
+                x = torch.cat([x, x])
+                y = x @ self.weight1
+                z = (x + y @ self.weight2) * (nz + 1).sum()
+                return z
+
+        model = Model()
+        model = FakeDDP(model)
+
+        opt_model = torch.compile(dynamic=True)(model)
+        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))
+
     @patch.object(config, "optimize_ddp", True)
     def test_call_method_forward(self):
         class Model(nn.Module):
@@ -8,10 +8,11 @@ import torch
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
+from torch.fx._utils import lazy_format_graph_code


 __all__ = ["Partition", "split_module"]
-_LOGGER = logging.getLogger(__name__)
+log = _LOGGER = logging.getLogger(__name__)

 @compatibility(is_backward_compatible=True)
 class Partition:
@@ -143,6 +144,13 @@ def split_module(
         True
     """

+    log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "pre split_module", m, colored=True
+        ),
+    )
+
     def construct_graph(
         node: Node,
         base_mod_env: Dict[str, Node],
@@ -184,6 +192,12 @@ def split_module(
         defined = getattr(def_node, "_fx_partition", None)
         used = getattr(use_node, "_fx_partition", None)

+        log.debug(
+            "record_cross_partition_use %s (%s) %s (%s)",
+            def_node.name, defined, use_node.name if use_node is not None else "-", used
+        )
+
         if defined != used:
             if defined is not None:
                 def_partition = partitions[defined]
@@ -194,14 +208,33 @@ def split_module(
             if used is not None:
                 use_partition = partitions[used]
                 use_partition.inputs.setdefault(def_node.name)
+                # We have made def_node an input to the use_partition. If
+                # this input has symbolic symbols in its size, those also must
+                # be made as inputs to the partition
                 if (def_val := def_node.meta.get("example_value")) is not None:
                     for s in sorted(free_symbols(def_val), key=str):
-                        use_partition.inputs.setdefault(symbol_to_node[s].name)
+                        s_node = symbol_to_node[s]
+                        use_partition.inputs.setdefault(s_node.name)
+                        if symbol_to_node[s].op != "placeholder":
+                            # If the node that defines the symbol is not a
+                            # placeholder, we must make it an output of the
+                            # partition. Note that this may be in a different
+                            # partition than defined! Although, this doesn't
+                            # really make a difference for correctness, since
+                            # defined is guaranteed to have the symbol in
+                            # scope and can return it; you just get less
+                            # optimal codegen in this case.
+                            s_defined = getattr(s_node, "_fx_partition", None)
+                            if s_defined is not None:
+                                s_def_partition = partitions[s_defined]
+                                s_def_partition.outputs.setdefault(s_node.name)
+                                s_def_partition.dependents.setdefault(used)
             if defined is not None:
                 use_partition.dependencies.setdefault(defined)

     def instantiate_node_partition_mapping(node):
         partition_name = str(split_callback(node))
+        log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)

         # add node to partitions
         partition = partitions.get(partition_name)
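The heart of the fix is the block above: when a cross-partition value's example_value mentions free symbols, the node that binds each symbol must also become an input of the consuming partition, and, if that binder is not a placeholder (the unbacked case), an output of whichever partition it lives in. A simplified sketch of that bookkeeping with toy stand-ins (plain dicts instead of the real Partition class, and a hand-built symbol_to_node table); only sympy is assumed:

import sympy

u0 = sympy.Symbol("u0")

# Toy partitions: "0" binds u0 (e.g. via .tolist()), "1" consumes a tensor whose size mentions u0.
partitions = {
    "0": {"inputs": {}, "outputs": {}, "dependents": {}, "dependencies": {}},
    "1": {"inputs": {}, "outputs": {}, "dependents": {}, "dependencies": {}},
}
symbol_to_node = {u0: {"name": "u0_node", "op": "call_function", "partition": "0"}}

def record_symbol_inputs(size_expr, used):
    # Mirror of the rule above: every free symbol in the crossing value's size
    # becomes an input of the consuming partition; non-placeholder binders also
    # become outputs of their own partition so the symbol is actually passed along.
    use_partition = partitions[used]
    for s in sorted(size_expr.free_symbols, key=str):
        binder = symbol_to_node[s]
        use_partition["inputs"].setdefault(binder["name"])
        if binder["op"] != "placeholder":
            def_partition = partitions[binder["partition"]]
            def_partition["outputs"].setdefault(binder["name"])
            def_partition["dependents"].setdefault(used)

record_symbol_inputs(2 * u0, used="1")
print(partitions["1"]["inputs"])   # {'u0_node': None}
print(partitions["0"]["outputs"])  # {'u0_node': None}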
@@ -239,14 +272,22 @@ def split_module(
     active_autocasts = set()

     for node in m.graph.nodes:
+        # This will prefer placeholder bindings, because those come first.
+        # This is a little dangerous though: it is possible that an unbacked
+        # symbol is used without any binding site for it, in which case we
+        # will get a KeyError not able to find it. I'd like to fix this by
+        # having passes.runtime_assert establish some invariants that I can
+        # rely on later, but this needs some extra work. Quick fix first.
+        # See https://github.com/pytorch/pytorch/issues/130534
+        if (
+            (val := node.meta.get("example_value")) is not None and
+            isinstance(val, torch.SymInt) and
+            isinstance(s0 := val.node.expr, sympy.Symbol) and
+            s0 not in symbol_to_node
+        ):
+            symbol_to_node[val.node.expr] = node
+
         if node.op in ["placeholder", "get_attr", "output"]:
-            if (
-                node.op == "placeholder" and
-                (val := node.meta.get("example_value")) is not None and
-                isinstance(val, torch.SymInt) and
-                isinstance(val.node.expr, sympy.Symbol)
-            ):
-                symbol_to_node[val.node.expr] = node
             continue

         instantiate_node_partition_mapping(node)
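The loop above is what the commit title refers to: previously only placeholder-bound symbols were recorded, so an unbacked symbol defined mid-graph was missing from symbol_to_node and the lookup in record_cross_partition_use raised a KeyError. A toy illustration of the "first binding wins" rule (hypothetical FakeNode stand-ins, not torch.fx nodes; only the sympy.Symbol check mirrors the real code):

import sympy

class FakeNode:
    """Hypothetical stand-in: .expr plays the role of node.meta["example_value"].node.expr."""
    def __init__(self, name, op, expr=None):
        self.name, self.op, self.expr = name, op, expr

graph = [
    FakeNode("s0_ph", "placeholder", sympy.Symbol("s0")),      # backed size, bound by a placeholder
    FakeNode("x", "placeholder"),
    FakeNode("u0_item", "call_function", sympy.Symbol("u0")),  # unbacked symbol bound mid-graph
    FakeNode("s0_dup", "call_function", sympy.Symbol("s0")),   # later re-binding is ignored
]

symbol_to_node = {}
for node in graph:
    # Placeholders come first in graph order, so they win; unbacked bindings
    # like u0 are now recorded too instead of being skipped entirely.
    if isinstance(node.expr, sympy.Symbol) and node.expr not in symbol_to_node:
        symbol_to_node[node.expr] = node

print({str(s): n.name for s, n in symbol_to_node.items()})
# {'s0': 's0_ph', 'u0': 'u0_item'}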
@@ -524,4 +565,11 @@ def split_module(
                 torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
             )  # noqa: B950

-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    log.debug(
+        "%s",
+        lazy_format_graph_code(
+            "post split_module", ret, colored=True
+        ),
+    )
+    return ret
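A self-contained sketch (not part of the PR) of driving the pass directly and surfacing the new pre/post-split DEBUG graphs. The logger name is assumed from the module path, since the pass logs via logging.getLogger(__name__); setting TORCH_LOGS="+torch.fx.passes.split_module" should be an equivalent switch, but that is an assumption about the logging registry.

import logging
import operator

import torch
import torch.fx
from torch.fx.passes.split_module import split_module

# Attach a handler directly so the DEBUG records added by this commit are visible.
log = logging.getLogger("torch.fx.passes.split_module")
log.setLevel(logging.DEBUG)
log.addHandler(logging.StreamHandler())

class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1      # routed to partition 0 by the callback below
        b = a * 2      # partition 1
        return b - 3   # partition 1

gm = torch.fx.symbolic_trace(M())

def split_callback(node: torch.fx.Node) -> int:
    # Toy policy: the addition goes to partition 0, everything else to partition 1.
    # split_callback is only consulted for compute nodes; placeholders, get_attr,
    # and output nodes are skipped by the pass itself.
    return 0 if node.target is operator.add else 1

split = split_module(gm, M(), split_callback)
print(split.graph)  # the base module now calls submod_0 and submod_1 in turn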