diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py
index f4e85f01e09..603a823d7be 100644
--- a/test/higher_order_ops/test_local_map.py
+++ b/test/higher_order_ops/test_local_map.py
@@ -556,6 +556,32 @@ class GraphModule(torch.nn.Module):
         ):
             torch.compile(mismatch_outputs, backend="eager", fullgraph=True)(x)
 
+    @unittest.skipIf(*get_skip_reasons())
+    def test_local_map_dynamo_reordered_inputs(self):
+        @local_map(
+            out_placements=((Shard(0), Shard(0)),),
+            in_placements=(
+                (Shard(0), Shard(0)),
+                (Replicate(), Shard(0)),
+            ),
+            redistribute_inputs=True,
+            in_grad_placements=None,
+            device_mesh=self.mesh,
+        )
+        def reorder_inputs(first_input, second_input):
+            return second_input.sum() * 10 + first_input  # dynamo will reorder inputs
+
+        x = torch.randn(64, 64, 64, requires_grad=True)
+        y = torch.randn(8, 64, 64, requires_grad=True)
+        with (
+            LocalMapWrappedHigherOrderVariable.enable(),
+            self.assertRaisesRegex(
+                AssertionError,
+                r"Dynamo changed the order of inputs to the local_map function, please adjust the order of inputs and input_placements from \[l_args_0_, l_args_1_\], to: \[l_args_1_, l_args_0_\].*",
+            ),
+        ):
+            torch.compile(reorder_inputs, backend="eager", fullgraph=True)(x, y)
+
     @unittest.skipIf(*get_skip_reasons())
     def test_local_map_with_local_shapes_hop_tracing(self):
         def fn(x):
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 18fc80f3898..7676e2def57 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -3711,6 +3711,23 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
                 make_error_msg(expected_num_outputs, actual_num_outputs, "outputs")
             )
 
+        if inputs_none_placements > 0:
+            expected_input_nodes = [
+                arg.as_proxy().node for arg in user_args[:-inputs_none_placements]
+            ]
+        else:
+            expected_input_nodes = [arg.as_proxy().node for arg in user_args]
+        actual_input_nodes = [proxy.node for proxy in p_args]
+        assert actual_input_nodes[0].op == "get_attr"
+        assert "subgraph" in actual_input_nodes[0].target
+        assert len(expected_input_nodes) == len(actual_input_nodes) - 1
+        for expected_order, actual_order in zip(
+            expected_input_nodes, actual_input_nodes[1:]
+        ):
+            assert expected_order == actual_order, (
+                "Dynamo changed the order of inputs to the local_map function, please adjust "
+                f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}"
+            )
         assert len(p_kwargs) == 0
 
         flat_example_value = pytree.tree_map_only(
@@ -3752,6 +3769,8 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
             vt.as_proxy().node.meta["example_value"] = global_tensor
             vt.synchronize_attributes(tx)
 
+        # TODO: Figure out how to handle output order diverging from eager
+
         # Treat as const, so we don't have to deal with Placement types in fx IR
         # Guarded with EQUALS_MATCH on local_map call's arguments
         body_gmod.meta["local_map_kwargs"] = {