mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] local_map error message for reordered inputs (#164780)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164780 Approved by: https://github.com/mlazos
This commit is contained in:
parent
74336f8c77
commit
a76b59cc45
|
|
@ -556,6 +556,32 @@ class GraphModule(torch.nn.Module):
|
|||
):
|
||||
torch.compile(mismatch_outputs, backend="eager", fullgraph=True)(x)
|
||||
|
||||
@unittest.skipIf(*get_skip_reasons())
|
||||
def test_local_map_dynamo_reordered_inputs(self):
|
||||
@local_map(
|
||||
out_placements=((Shard(0), Shard(0)),),
|
||||
in_placements=(
|
||||
(Shard(0), Shard(0)),
|
||||
(Replicate(), Shard(0)),
|
||||
),
|
||||
redistribute_inputs=True,
|
||||
in_grad_placements=None,
|
||||
device_mesh=self.mesh,
|
||||
)
|
||||
def reorder_inputs(first_input, second_input):
|
||||
return second_input.sum() * 10 + first_input # dynamo will reorder inputs
|
||||
|
||||
x = torch.randn(64, 64, 64, requires_grad=True)
|
||||
y = torch.randn(8, 64, 64, requires_grad=True)
|
||||
with (
|
||||
LocalMapWrappedHigherOrderVariable.enable(),
|
||||
self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
r"Dynamo changed the order of inputs to the local_map function, please adjust the order of inputs and input_placements from \[l_args_0_, l_args_1_\], to: \[l_args_1_, l_args_0_\].*",
|
||||
),
|
||||
):
|
||||
torch.compile(reorder_inputs, backend="eager", fullgraph=True)(x, y)
|
||||
|
||||
@unittest.skipIf(*get_skip_reasons())
|
||||
def test_local_map_with_local_shapes_hop_tracing(self):
|
||||
def fn(x):
|
||||
|
|
|
|||
|
|
@ -3711,6 +3711,23 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
|||
make_error_msg(expected_num_outputs, actual_num_outputs, "outputs")
|
||||
)
|
||||
|
||||
if inputs_none_placements > 0:
|
||||
expected_input_nodes = [
|
||||
arg.as_proxy().node for arg in user_args[:-inputs_none_placements]
|
||||
]
|
||||
else:
|
||||
expected_input_nodes = [arg.as_proxy().node for arg in user_args]
|
||||
actual_input_nodes = [proxy.node for proxy in p_args]
|
||||
assert actual_input_nodes[0].op == "get_attr"
|
||||
assert "subgraph" in actual_input_nodes[0].target
|
||||
assert len(expected_input_nodes) == len(actual_input_nodes) - 1
|
||||
for expected_order, actual_order in zip(
|
||||
expected_input_nodes, actual_input_nodes[1:]
|
||||
):
|
||||
assert expected_order == actual_order, (
|
||||
"Dynamo changed the order of inputs to the local_map function, please adjust "
|
||||
f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}"
|
||||
)
|
||||
assert len(p_kwargs) == 0
|
||||
|
||||
flat_example_value = pytree.tree_map_only(
|
||||
|
|
@ -3752,6 +3769,8 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
|||
vt.as_proxy().node.meta["example_value"] = global_tensor
|
||||
vt.synchronize_attributes(tx)
|
||||
|
||||
# TODO: Figure out how to handle output order diverging from eager
|
||||
|
||||
# Treat as const, so we don't have to deal with Placement types in fx IR
|
||||
# Guarded with EQUALS_MATCH on local_map call's arguments
|
||||
body_gmod.meta["local_map_kwargs"] = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user