[dynamo] local_map error message for reordered inputs (#164780)
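local_map pairs in_placements positionally with the wrapped function's inputs, but Dynamo lifts tensors into the traced graph in first-use order, which can differ from the signature order. This change adds a trace-time assertion that the lifted input order matches the eager order, and raises an AssertionError telling the user how to permute their inputs and input_placements when it does not. Includes a test exercising the new message.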

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164780
Approved by: https://github.com/mlazos
Simon Fan 2025-10-26 16:32:39 -07:00 committed by PyTorch MergeBot
parent 74336f8c77
commit a76b59cc45
2 changed files with 45 additions and 0 deletions

@@ -556,6 +556,32 @@ class GraphModule(torch.nn.Module):
        ):
            torch.compile(mismatch_outputs, backend="eager", fullgraph=True)(x)

    @unittest.skipIf(*get_skip_reasons())
    def test_local_map_dynamo_reordered_inputs(self):
        @local_map(
            out_placements=((Shard(0), Shard(0)),),
            in_placements=(
                (Shard(0), Shard(0)),
                (Replicate(), Shard(0)),
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=self.mesh,
        )
        def reorder_inputs(first_input, second_input):
            return second_input.sum() * 10 + first_input  # dynamo will reorder inputs

        x = torch.randn(64, 64, 64, requires_grad=True)
        y = torch.randn(8, 64, 64, requires_grad=True)

        with (
            LocalMapWrappedHigherOrderVariable.enable(),
            self.assertRaisesRegex(
                AssertionError,
                r"Dynamo changed the order of inputs to the local_map function, please adjust the order of inputs and input_placements from \[l_args_0_, l_args_1_\], to: \[l_args_1_, l_args_0_\].*",
            ),
        ):
            torch.compile(reorder_inputs, backend="eager", fullgraph=True)(x, y)

    @unittest.skipIf(*get_skip_reasons())
    def test_local_map_with_local_shapes_hop_tracing(self):
        def fn(x):
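Note on the test above: Dynamo creates graph placeholders in the order tensors are first used inside the traced function, not in signature order, which is why the captured order flips to [l_args_1_, l_args_0_]. A minimal standalone sketch of observing this, assuming a custom capture backend (capture_backend and captured are illustrative names, not part of this PR):

    import torch

    def uses_second_arg_first(first_input, second_input):
        # second_input is touched first, so Dynamo lifts it as the
        # first placeholder of the captured FX graph
        return second_input.sum() * 10 + first_input

    captured = {}

    def capture_backend(gm, example_inputs):
        # stash the traced GraphModule so we can inspect placeholder order
        captured["gm"] = gm
        return gm.forward

    x = torch.randn(4, 4)
    y = torch.randn(4, 4)
    torch.compile(uses_second_arg_first, backend=capture_backend, fullgraph=True)(x, y)
    # expect second_input's placeholder to appear before first_input's
    print([n.name for n in captured["gm"].graph.nodes if n.op == "placeholder"])

The assertion added below compares exactly this lifted order (actual_input_nodes[1:], skipping the leading subgraph get_attr) against the eager argument order.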

@@ -3711,6 +3711,23 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
                make_error_msg(expected_num_outputs, actual_num_outputs, "outputs")
            )

        if inputs_none_placements > 0:
            expected_input_nodes = [
                arg.as_proxy().node for arg in user_args[:-inputs_none_placements]
            ]
        else:
            expected_input_nodes = [arg.as_proxy().node for arg in user_args]

        actual_input_nodes = [proxy.node for proxy in p_args]
        assert actual_input_nodes[0].op == "get_attr"
        assert "subgraph" in actual_input_nodes[0].target
        assert len(expected_input_nodes) == len(actual_input_nodes) - 1
        for expected_order, actual_order in zip(
            expected_input_nodes, actual_input_nodes[1:]
        ):
            assert expected_order == actual_order, (
                "Dynamo changed the order of inputs to the local_map function, please adjust "
                f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}"
            )

        assert len(p_kwargs) == 0

        flat_example_value = pytree.tree_map_only(
@@ -3752,6 +3769,8 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
            vt.as_proxy().node.meta["example_value"] = global_tensor
            vt.synchronize_attributes(tx)

        # TODO: Figure out how to handle output order diverging from eager
        # Treat as const, so we don't have to deal with Placement types in fx IR
        # Guarded with EQUALS_MATCH on local_map call's arguments
        body_gmod.meta["local_map_kwargs"] = {