mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] local_map error message for reordered inputs (#164780)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164780 Approved by: https://github.com/mlazos
This commit is contained in:
parent
74336f8c77
commit
a76b59cc45
|
|
@ -556,6 +556,32 @@ class GraphModule(torch.nn.Module):
|
||||||
):
|
):
|
||||||
torch.compile(mismatch_outputs, backend="eager", fullgraph=True)(x)
|
torch.compile(mismatch_outputs, backend="eager", fullgraph=True)(x)
|
||||||
|
|
||||||
|
@unittest.skipIf(*get_skip_reasons())
|
||||||
|
def test_local_map_dynamo_reordered_inputs(self):
|
||||||
|
@local_map(
|
||||||
|
out_placements=((Shard(0), Shard(0)),),
|
||||||
|
in_placements=(
|
||||||
|
(Shard(0), Shard(0)),
|
||||||
|
(Replicate(), Shard(0)),
|
||||||
|
),
|
||||||
|
redistribute_inputs=True,
|
||||||
|
in_grad_placements=None,
|
||||||
|
device_mesh=self.mesh,
|
||||||
|
)
|
||||||
|
def reorder_inputs(first_input, second_input):
|
||||||
|
return second_input.sum() * 10 + first_input # dynamo will reorder inputs
|
||||||
|
|
||||||
|
x = torch.randn(64, 64, 64, requires_grad=True)
|
||||||
|
y = torch.randn(8, 64, 64, requires_grad=True)
|
||||||
|
with (
|
||||||
|
LocalMapWrappedHigherOrderVariable.enable(),
|
||||||
|
self.assertRaisesRegex(
|
||||||
|
AssertionError,
|
||||||
|
r"Dynamo changed the order of inputs to the local_map function, please adjust the order of inputs and input_placements from \[l_args_0_, l_args_1_\], to: \[l_args_1_, l_args_0_\].*",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
torch.compile(reorder_inputs, backend="eager", fullgraph=True)(x, y)
|
||||||
|
|
||||||
@unittest.skipIf(*get_skip_reasons())
|
@unittest.skipIf(*get_skip_reasons())
|
||||||
def test_local_map_with_local_shapes_hop_tracing(self):
|
def test_local_map_with_local_shapes_hop_tracing(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
|
|
|
||||||
|
|
@ -3711,6 +3711,23 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
||||||
make_error_msg(expected_num_outputs, actual_num_outputs, "outputs")
|
make_error_msg(expected_num_outputs, actual_num_outputs, "outputs")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if inputs_none_placements > 0:
|
||||||
|
expected_input_nodes = [
|
||||||
|
arg.as_proxy().node for arg in user_args[:-inputs_none_placements]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
expected_input_nodes = [arg.as_proxy().node for arg in user_args]
|
||||||
|
actual_input_nodes = [proxy.node for proxy in p_args]
|
||||||
|
assert actual_input_nodes[0].op == "get_attr"
|
||||||
|
assert "subgraph" in actual_input_nodes[0].target
|
||||||
|
assert len(expected_input_nodes) == len(actual_input_nodes) - 1
|
||||||
|
for expected_order, actual_order in zip(
|
||||||
|
expected_input_nodes, actual_input_nodes[1:]
|
||||||
|
):
|
||||||
|
assert expected_order == actual_order, (
|
||||||
|
"Dynamo changed the order of inputs to the local_map function, please adjust "
|
||||||
|
f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}"
|
||||||
|
)
|
||||||
assert len(p_kwargs) == 0
|
assert len(p_kwargs) == 0
|
||||||
|
|
||||||
flat_example_value = pytree.tree_map_only(
|
flat_example_value = pytree.tree_map_only(
|
||||||
|
|
@ -3752,6 +3769,8 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
|
||||||
vt.as_proxy().node.meta["example_value"] = global_tensor
|
vt.as_proxy().node.meta["example_value"] = global_tensor
|
||||||
vt.synchronize_attributes(tx)
|
vt.synchronize_attributes(tx)
|
||||||
|
|
||||||
|
# TODO: Figure out how to handle output order diverging from eager
|
||||||
|
|
||||||
# Treat as const, so we don't have to deal with Placement types in fx IR
|
# Treat as const, so we don't have to deal with Placement types in fx IR
|
||||||
# Guarded with EQUALS_MATCH on local_map call's arguments
|
# Guarded with EQUALS_MATCH on local_map call's arguments
|
||||||
body_gmod.meta["local_map_kwargs"] = {
|
body_gmod.meta["local_map_kwargs"] = {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user