From 821458d97a89b2f3f559b0000c3ef1500b7fd7dc Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 16 Sep 2025 14:03:49 -0700 Subject: [PATCH] [dynamo][hop] Introduce Local Map HOP (#161458) Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458 Approved by: https://github.com/ydwu4 --- test/dynamo/test_higher_order_ops.py | 1 + test/higher_order_ops/test_local_map.py | 203 ++++++++++++ test/inductor/test_compiled_autograd.py | 2 +- torch/_dynamo/variables/builder.py | 7 +- torch/_dynamo/variables/higher_order_ops.py | 111 +++++++ torch/_higher_order_ops/__init__.py | 2 + torch/_higher_order_ops/local_map.py | 327 ++++++++++++++++++++ torch/distributed/tensor/_ops/_view_ops.py | 4 +- torch/testing/_internal/hop_db.py | 47 +++ 9 files changed, 701 insertions(+), 3 deletions(-) create mode 100644 test/higher_order_ops/test_local_map.py create mode 100644 torch/_higher_order_ops/local_map.py diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 9f093d4dc0c..78943b41bc2 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -7197,6 +7197,7 @@ xfail_hops_compile = { # aot_eager "map", # assert type(args[1].realize()) is TensorVariable "scan", # scan is not an OpOverload + "local_map_hop", # can't retrace # inductor "while_loop", # LoweringException: AssertionError "flex_attention", # LoweringException: AssertionError diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py new file mode 100644 index 00000000000..46ecacc2b33 --- /dev/null +++ b/test/higher_order_ops/test_local_map.py @@ -0,0 +1,203 @@ +# Owner(s): ["module: higher order operators"] +# flake8: noqa: B950 + + +import unittest + +import torch +import torch._dynamo +import torch._functorch +import torch._inductor +import torch._inductor.decomposition +import torch.nn.functional as F +from torch import nn +from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable + + +if torch.distributed.is_available(): + from torch.distributed._tensor.experimental import local_map + from torch.distributed.tensor.placement_types import Replicate, Shard + +from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase +from torch.testing._internal.triton_utils import requires_cuda_and_triton + + +nested_compile_region = torch.compiler.nested_compile_region + + +class MyTransform(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + 100 + + @staticmethod + def backward(ctx, grad): + return grad + 100 + + +def context_parallel_attention(query, key, value): + out = F.scaled_dot_product_attention( + query=query, key=key, value=value, is_causal=False + ) + return out + + +def create_model(attention_fn, nheads, dim1, dim2): + class LocalMapTransformerBlock(nn.Module): + def __init__(self, nheads, dim1, dim2): + super().__init__() + self.nheads = nheads + bias = False + self.wq = nn.Linear(dim1, dim1, bias=bias) + self.wk = nn.Linear(dim1, dim1, bias=bias) + self.wv = nn.Linear(dim1, dim1, bias=bias) + self.wo = nn.Linear(dim1, dim1, bias=bias) + self.w1 = nn.Linear(dim1, dim2, bias=bias) + self.w2 = nn.Linear(dim2, dim1, bias=bias) + + def forward(self, x): + q = self.wq(x) + k = self.wk(x) + v = self.wv(x) + + q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + v = v.unflatten(-1, 
(self.nheads, -1)).permute(0, 2, 1, 3) + + o = attention_fn(q, k, v) + o = o.permute(0, 2, 1, 3).flatten(-2) + + o = self.wo(o) + + o0 = o + x + + o = self.w1(o0) + o = torch.nn.functional.relu(o) + o = self.w2(o) + + o = o0 + o + return o + + return LocalMapTransformerBlock(nheads, dim1, dim2) + + +class TestLocalMap(TestCase): + @requires_cuda_and_triton + @unittest.skipIf( + not torch.distributed.is_available(), "Torch distributed not available." + ) + def test_simple(self): + @local_map( + out_placements=((Shard(0), Shard(1), Shard(2)),), + in_placements=( + (Shard(0), Shard(1), Shard(2)), # query + (Shard(0), Shard(1), Replicate()), # key + (Shard(0), Shard(1), Replicate()), # value + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, + ) + def cp_decorated(query, key, value): + return context_parallel_attention(query, key, value) + + cp_function = local_map( + context_parallel_attention, + out_placements=(Shard(0), Shard(1), Shard(2)), + in_placements=( + (Shard(0), Shard(1), Shard(2)), # query + (Shard(0), Shard(1), Replicate()), # key + (Shard(0), Shard(1), Replicate()), # value + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, + ) + bs = 8 * 1 + dim1 = 96 + dim2 = dim1 * 4 + nheads = 16 + seq_len = 16 + + from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm + + backend = EagerAndRecordGraphs() + + model = create_model(cp_decorated, nheads, dim1, dim2).cuda() + inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),) + with LocalMapWrappedHigherOrderVariable.enable(): + out = torch.compile(model, backend=backend)(*inputs) + out.sum().backward() + + model = create_model(cp_function, nheads, dim1, dim2).cuda() + inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),) + with LocalMapWrappedHigherOrderVariable.enable(): + out = torch.compile(model, backend=backend)(*inputs) + out.sum().backward() + + if not TEST_WITH_CROSSREF: + self.assertEqual(len(backend.graphs), 2) + # should see local_map_hop in both + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_self_modules_wq_parameters_weight_: "f32[96, 96]", L_x_: "f32[8, 16, 96]", L_self_modules_wk_parameters_weight_: "f32[96, 96]", L_self_modules_wv_parameters_weight_: "f32[96, 96]", L_self_modules_wo_parameters_weight_: "f32[96, 96]", L_self_modules_w1_parameters_weight_: "f32[384, 96]", L_self_modules_w2_parameters_weight_: "f32[96, 384]"): + l_self_modules_wq_parameters_weight_ = L_self_modules_wq_parameters_weight_ + l_x_ = L_x_ + l_self_modules_wk_parameters_weight_ = L_self_modules_wk_parameters_weight_ + l_self_modules_wv_parameters_weight_ = L_self_modules_wv_parameters_weight_ + l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_ + l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_ + l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_ + + q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None + + k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None + + v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None + + unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None + q_1: "f32[8, 16, 16, 
6]" = unflatten.permute(0, 2, 1, 3); unflatten = None + + unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None + k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None + + unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None + v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None + + subgraph_0 = self.subgraph_0 + local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None + o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None + + permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None + o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None + + o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None + + o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None + + o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None + + o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None + + o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None + + o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None + return (o_6,) + + class subgraph_0(torch.nn.Module): + def forward(self, q_1: "f32[8, 16, 16, 6]", k_1: "f32[8, 16, 16, 6]", v_1: "f32[8, 16, 16, 6]"): + out: "f32[8, 16, 16, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None + return (out,) +""", + ) + + self.assertEqual( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + normalize_gm(backend.graphs[1].print_readable(print_output=False)), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 6014a6e6986..e0cd8b99a6b 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -5354,7 +5354,7 @@ if torch.distributed.is_available() and HAS_CUDA_AND_TRITON: test_dtensor.TestDTensorCompile ) -xfail_hops = {} +xfail_hops = {"local_map_hop"} class TestCompiledAutogradOpInfo(TestCase): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 20b88759ef3..660042b33b8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -206,7 +206,10 @@ from .functions import ( UserMethodVariable, WrapperUserFunctionVariable, ) -from .higher_order_ops import TorchHigherOrderOperatorVariable +from .higher_order_ops import ( + LocalMapWrappedHigherOrderVariable, + TorchHigherOrderOperatorVariable, +) from .iter import ItertoolsVariable from .lazy import LazyVariableTracker from .lists import ( @@ -850,6 +853,8 @@ class VariableBuilder: return build_checkpoint_variable(source=self.source) elif is_invoke_subgraph(value): return build_invoke_subgraph_variable(source=self.source) + elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value): + return LocalMapWrappedHigherOrderVariable.build(source=self.source) elif isinstance(value, functools.partial): func_src = AttrSource(self.get_source(), "func") func_obj = VariableBuilder(self.tx, func_src)(value.func) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7135bba36df..41a22972cba 100644 --- a/torch/_dynamo/variables/higher_order_ops.py 
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -3383,6 +3383,7 @@ class BaseHOPVariable(WrapHigherOrderVariable):
             lambda a: a.node.meta["example_value"],
             body_r.as_proxy(),
         )
+        p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
 
         return _call_function_and_unflatten_output(
             tx, self.value, p_args, p_kwargs, flat_example_value, treespec
@@ -3497,6 +3498,115 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
         )
 
+
+class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
+    supports_input_mutation = False
+    supports_aliasing = False
+
+    # Subclasses aren't supported by speculate_subgraph yet,
+    # so this HOP is only usable with plain tensors
+    _enabled = False
+
+    @classmethod
+    @contextlib.contextmanager
+    def enable(cls):
+        """Context manager to temporarily enable local map wrapping.
+        Will be removed when speculate_subgraph supports subclass inputs:
+        https://github.com/pytorch/pytorch/issues/161456.
+
+        Usage:
+            with LocalMapWrappedHigherOrderVariable.enable():
+                # Code where should_wrap_in_hop will return True
+                pass
+        """
+        old_value = cls._enabled
+        cls._enabled = True
+        try:
+            yield
+        finally:
+            cls._enabled = old_value
+
+    @classmethod
+    def should_wrap_in_hop(cls, value):
+        if not torch.distributed.is_available():
+            return False
+
+        from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
+
+        # the type check is important to avoid subclass dispatch
+        if type(value) != type(_local_map_wrapped):
+            return False
+
+        return value == _local_map_wrapped and cls._enabled
+
+    @staticmethod
+    def build(**options):
+        return TorchHigherOrderOperatorVariable.make(
+            torch._higher_order_ops.local_map_hop,
+            **options,
+        )
+
+    def python_type(self):
+        return type(self.value)
+
+    def _call_function(
+        self,
+        tx: "InstructionTranslator",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        """
+        The goal of this function is to rewrite local_map usage as a HOP:
+        local_map(func, ...) -> local_map_hop(gm, ...)
+ """ + + ( + user_func, + out_placements, + in_placements, + in_grad_placements, + device_mesh, + redistribute_inputs, + *user_args, + ) = args + + ( + p_args, + p_kwargs, + example_value, + body_r, + treespec, + body_gmod, + body_name, + ) = self.create_wrapped_node( + tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph" + ) + + # Treat as const, so we don't have to deal with Placement types in fx IR + # Guarded with EQUALS_MATCH on local_map call's arguments + body_gmod.meta["local_map_kwargs"] = { + "out_placements": out_placements.value, + "in_placements": in_placements.value, + "redistribute_inputs": redistribute_inputs.value, + "in_grad_placements": in_grad_placements.value, + "device_mesh": device_mesh.value, + } + + assert len(p_kwargs) == 0 + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + out = _call_function_and_unflatten_output( + tx, self.value, p_args, p_kwargs, flat_example_value, treespec + ) + + return out + + # Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() _hop_name_to_variable_class = { "cond": CondHigherOrderVariable, @@ -3525,4 +3635,5 @@ _hop_name_to_variable_class = { "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable, "invoke_subgraph": InvokeSubgraphHigherOrderVariable, "custom_function_call": CustomFunctionHigherOrderOperatorVariable, + "local_map_hop": LocalMapWrappedHigherOrderVariable, } diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index e809c729dc4..516d58bdf31 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -21,6 +21,7 @@ from torch._higher_order_ops.flex_attention import ( from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.invoke_subgraph import invoke_subgraph +from torch._higher_order_ops.local_map import local_map_hop from torch._higher_order_ops.map import map from torch._higher_order_ops.out_dtype import out_dtype from torch._higher_order_ops.run_const_graph import run_const_graph @@ -73,4 +74,5 @@ __all__ = [ "aoti_call_delegate", "map", "while_loop_stack_output", + "local_map_hop", ] diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py new file mode 100644 index 00000000000..22cb2af50f1 --- /dev/null +++ b/torch/_higher_order_ops/local_map.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+
+# NOTE: this file may be removed once we move to a dynamo frontend
+
+import functools
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any, Callable, Optional
+
+import torch
+import torch.utils._pytree as pytree
+from torch._C import DispatchKey
+from torch._higher_order_ops.utils import (
+    clone_outputs_aliasing_inputs,
+    save_tensors_and_symints_for_backward,
+    saved_tensors_and_symints,
+)
+from torch._ops import HigherOrderOperator
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx import GraphModule
+from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
+
+
+# Proxy the HOP instead of inlining into it
+_DEFER_INLINING = False
+
+
+@contextmanager
+def defer_inlining() -> Generator[None, None, None]:
+    global _DEFER_INLINING
+    prior = _DEFER_INLINING
+    try:
+        _DEFER_INLINING = True
+        yield
+    finally:
+        _DEFER_INLINING = prior
+
+
+class LocalMapHOP(HigherOrderOperator):
+    def __init__(self) -> None:
+        super().__init__("local_map_hop")
+
+    def __call__(self, fw_gm: GraphModule, *args: Any, **kwargs: Any) -> Any:
+        return super().__call__(fw_gm, *args, **kwargs)
+
+
+local_map_hop = LocalMapHOP()
+
+
+def create_hop_fw_bw(
+    fw_gm: GraphModule,
+    *_args: Any,
+) -> tuple[GraphModule, GraphModule, int, int, set[int]]:
+    """
+    Traces a joint graph, applies passes, and partitions it into forward and backward graphs.
+    """
+    # Keeping these imports local to avoid circular dependencies
+    # until we upstream with the dynamo frontend
+    from torch._dispatch.python import suspend_functionalization
+    from torch._functorch.aot_autograd import AOTConfig, create_joint
+    from torch._guards import detect_fake_mode
+    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+    from torch._subclasses.functional_tensor import disable_functional_mode
+    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx
+
+    dummy_aot_config = AOTConfig(
+        fw_compiler=None,  # type: ignore[arg-type]
+        bw_compiler=None,  # type: ignore[arg-type]
+        partition_fn=None,  # type: ignore[arg-type]
+        decompositions={},
+        num_params_buffers=0,
+        aot_id=0,
+        keep_inference_input_mutations=False,
+    )
+
+    with suspend_functionalization(), disable_functional_mode():
+        with disable_proxy_modes_tracing():
+            # create a fake tensor from a compiler-wrapped FunctionalTensor
+            def _from_fun(t: Any) -> Any:
+                if isinstance(t, torch.Tensor):
+                    return torch.empty_strided(
+                        t.size(),
+                        t.stride(),
+                        device=t.device,
+                        dtype=t.dtype,
+                        requires_grad=t.requires_grad,
+                    )
+                return t
+
+            # If someone runs this hop under the default compiler backend ("eager"),
+            # then this path will be run with the actual user inputs. We convert them
+            # to fake tensors in order to not perform any actual compute.
+ + fake_mode = detect_fake_mode(_args) + if fake_mode is None: + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + with fake_mode: + fw_inputs = pytree.tree_map(_from_fun, _args) + + assert all( + isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs + ), f"Unexpected element in {fw_inputs=}" + + example_grads = pytree.tree_map( + _from_fun, + fw_gm(*fw_inputs), + ) + if not isinstance(example_grads, (list, tuple)): + example_grads = [example_grads] + + num_fw_inputs = len(fw_inputs) + num_fw_outputs = len(example_grads) + + def joint_f( + *primals_and_tangents: list[torch.Tensor], + ) -> Any: + primals = primals_and_tangents[:num_fw_inputs] + tangents = primals_and_tangents[num_fw_inputs:] + + def prepare_fw_with_masks(fn: Callable[..., Any]) -> Callable[..., Any]: + def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]: + fw_out = fn(*args) + assert isinstance(fw_out, tuple), ( + "Dynamo traced submodule should return tuple" + ) + return fw_out, [ + True + if isinstance(ret, torch.Tensor) and ret.requires_grad + else False + for ret in fw_out + ] + + return fw_with_masks + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fw_gm), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + # put grads first to work with existing hop utils + return pytree.tree_map(maybe_clone, (*grads, *fw_outs)) + + filtered_grads_idx = set() + for i, example_grad in enumerate(example_grads): + # Filter out grads that are None or do not require_grad. + # The AOTAutograd utils we rely on force this assumption. + # We must also filter the runtime tangents too. + if example_grad is not None and ( + isinstance(example_grad, torch.Tensor) and example_grad.requires_grad + ): + filtered_grads_idx.add(i) + + primals_and_tangents = [ + *fw_inputs, + *[example_grads[i] for i in filtered_grads_idx], + ] + joint_hop_gm = make_fx(joint_f)(*primals_and_tangents) + + from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner + from torch._inductor.compile_fx import partition_fn + + # Match partitioner convention + prepped_joint_hop_gm = prepare_for_partitioner( + joint_hop_gm, num_fw_inputs, num_fw_outputs + ) + # Also runs joint passes + new_fw_gm, new_bw_gm = partition_fn( + prepped_joint_hop_gm, + [], + num_fwd_outputs=num_fw_outputs, + static_lifetime_input_indices=[], + ) + + # Propagate meta onto fw/bw graphs, later will be set on proxied nodes + local_map_kwargs = fw_gm.meta["local_map_kwargs"] # type: ignore[attr-defined] + + new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs + new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs} + # Okay because Autoparallel assumes same sharding between param and grads + new_bw_gm.meta["local_map_kwargs"]["in_placements"] = local_map_kwargs[ + "out_placements" + ] + new_bw_gm.meta["local_map_kwargs"]["out_placements"] = local_map_kwargs[ + "in_placements" + ] + + return new_fw_gm, new_bw_gm, num_fw_inputs, num_fw_outputs, filtered_grads_idx + + +class LocalMapAutogradOp(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + fw_gm: GraphModule, + bw_gm: GraphModule, + num_fw_ins: int, + num_fw_outs: int, + filtered_grads_idx: set[int], + *args: Any, + **kwargs: Any, + ) -> tuple[Optional[torch.Tensor], ...]: + ctx.bw_gm = bw_gm + ctx.num_fw_ins = num_fw_ins + ctx.filtered_grads_idx = filtered_grads_idx + + with torch._C._AutoDispatchBelowAutograd(): + fw_outs_with_saved_activations = local_map_hop(fw_gm, *args, **kwargs) + + fw_outs 
= fw_outs_with_saved_activations[:num_fw_outs] + saved_activations = fw_outs_with_saved_activations[num_fw_outs:] + save_tensors_and_symints_for_backward(ctx, saved_activations) + + return fw_outs + + @staticmethod + def backward( + ctx: Any, *_grads: tuple[torch.Tensor] + ) -> tuple[Optional[torch.Tensor], ...]: + saved_activations = saved_tensors_and_symints(ctx) + with torch._C._AutoDispatchBelowAutograd(): + # Filter out grads that are None or do not require_grad. + # The AOTAutograd utils we rely on force this assumption. + grads = [_grads[i] for i in ctx.filtered_grads_idx] + grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads) + if len(grad_ins) != ctx.num_fw_ins: + raise RuntimeError( + f"Expected {ctx.num_fw_ins} grad_ins, got {len(grad_ins)}" + ) + return None, None, None, None, None, *grad_ins + + +@local_map_hop.py_impl(torch._C.DispatchKey.Autograd) +def autograd_key( + fw_gm: GraphModule, + *args: Any, + **kwargs: Any, +) -> Any: + if _DEFER_INLINING: + fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx = create_hop_fw_bw( + fw_gm, *args + ) + return LocalMapAutogradOp.apply( + fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx, *args, **kwargs + ) + + return fw_gm(*args, **kwargs) + + +@local_map_hop.py_functionalize_impl +def functional_mode_key( + ctx: Any, fw_gm: GraphModule, *args: Any, **kwargs: Any +) -> tuple[torch.Tensor]: + assert not kwargs + + unwrapped_inputs = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + out = local_map_hop(fw_gm, *unwrapped_inputs) + return ctx.wrap_tensors(out) + + +@local_map_hop.py_impl(FakeTensorMode) +def fake_mode_key( + mode: FakeTensorMode, + fw_gm: GraphModule, + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor]: + with mode: + return fw_gm(*args, **kwargs) + + +def proxy_mode_key_common( + call_hop: Callable[..., Any], + proxy_mode: ProxyTorchDispatchMode, + gm: GraphModule, + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor]: + assert proxy_mode is not None, ( + "Mode should always be enabled for python fallback key" + ) + assert len(kwargs) == 0 + + example_out = call_hop(*args, **kwargs) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) # type: ignore[union-attr] + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", call_hop, proxy_args, {} + ) + + # extract local_map args, post-dispatch operates on GraphModules + assert gm.meta["local_map_kwargs"] + local_map_kwargs = gm.meta["local_map_kwargs"] + + # propagate local_map args to the call_function node + out_proxy.node.meta["local_map_kwargs"] = local_map_kwargs + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@local_map_hop.py_impl(ProxyTorchDispatchMode) +def proxy_mode_key( + proxy_mode: ProxyTorchDispatchMode, + fw_gm: GraphModule, + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor]: + # TODO: get rid of this when we can install as a subgraph + def call_local_map(*_args: Any, **_kwargs: Any) -> Any: + return functools.partial(local_map_hop, fw_gm)(*_args, **_kwargs) + + return proxy_mode_key_common(call_local_map, proxy_mode, fw_gm, *args, **kwargs) + + +# Running HOP in eager with real tensors +@local_map_hop.py_impl(DispatchKey.CompositeExplicitAutograd) +def real_impl( + fw_gm: GraphModule, + *args: Any, + **kwargs: Any, +) -> tuple[torch.Tensor]: + return fw_gm(*args, **kwargs) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 62e8c68e9be..80a0491f694 100644 --- 
a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -490,7 +490,9 @@ def propagate_shape_and_sharding( - An output dimension that is a split of the input dimension can only be sharded if the leftmost split size is divisible by the mesh dimension """ - assert len(input_src_placements) == len(mesh_sizes) + assert len(input_src_placements) == len(mesh_sizes), ( + f"{input_src_placements} != {mesh_sizes}" + ) # for each input dim, for each mesh dim, provides a list of possible shardable dimensions mesh_ndim = len(mesh_sizes) shardable_dims: dict[int, list[bool]] = {} diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 2a088340889..fc6cfa8cf7f 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -212,6 +212,32 @@ def simple_while_loop_stack_output(iter_t, x): return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple()) +def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): + # TODO: once HOPs support DTensor inputs, we should also test DTensors + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=False + ) + yield SampleInput( + make_arg(2, 3, 4, low=0.1, high=2), + make_arg(2, 3, 4, low=0.1, high=2), + ) + + +def simple_local_map_hop(inp1, inp2): + def body_gm(inp1, inp2): + return inp1.cos() + inp2.sin() + gm = torch.fx.symbolic_trace(body_gm) + + assert torch.distributed.is_available() + from torch.distributed.tensor.placement_types import Replicate + gm.meta["local_map_kwargs"] = { + "in_placements": (Replicate(), Replicate(), Replicate()), + "out_placements": ((Replicate(), Replicate(), Replicate()),) + } + + # TODO: Dynamo would rewrite this op differently + return torch._higher_order_ops.local_map_hop(gm, inp1, inp2) + def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=requires_grad @@ -451,4 +477,25 @@ hop_db = [ ), decorators=[onlyCUDA], ), + OpInfo( + name="local_map_hop", + variant_test_name="simple", + op=simple_local_map_hop, + sample_inputs_func=sample_inputs_local_map_hop, + dtypes=custom_types(torch.float16, torch.float32), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")], + ), ]
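Reviewer note: below is a minimal eager-mode sketch of how the new HOP composes, using only the APIs added in this patch (local_map_hop and the local_map_kwargs meta contract read by create_hop_fw_bw). The body function, shapes, and placement layout are illustrative and assume a distributed build of torch; they are not part of the patch.

import torch
from torch.distributed.tensor.placement_types import Replicate
from torch._higher_order_ops.local_map import local_map_hop

def body(q, k, v):
    # plain-tensor compute; traced submodules are expected to return a tuple
    return (torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False),)

gm = torch.fx.symbolic_trace(body)
# The HOP reads sharding info from the GraphModule's meta (see create_hop_fw_bw);
# the placement layout here is illustrative only.
gm.meta["local_map_kwargs"] = {
    "in_placements": ((Replicate(),), (Replicate(),), (Replicate(),)),
    "out_placements": ((Replicate(),),),
}

q, k, v = [torch.randn(2, 4, 8, 16, requires_grad=True) for _ in range(3)]
# With _DEFER_INLINING left at its default (False), the Autograd key inlines
# straight into gm; wrapping the call in defer_inlining() would instead trace
# a joint graph and partition it via create_hop_fw_bw.
out = local_map_hop(gm, q, k, v)[0]
out.sum().backward()

Under torch.compile, the same rewrite is driven by LocalMapWrappedHigherOrderVariable.enable(), as exercised in test_local_map.py above.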