[dynamo][hop] Introduce Local Map HOP (#161458)

It can't actually be deployed yet because of https://github.com/pytorch/pytorch/issues/161456
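
As a rough usage sketch (it mirrors the new test_local_map.py added below; the function names, shapes, and the "eager" backend are illustrative, and the opt-in context manager only exists until the issue above is fixed), a local_map-decorated function gets captured by Dynamo as a single local_map_hop node:

import torch
import torch.nn.functional as F
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable

@local_map(
    out_placements=((Shard(0), Shard(1), Shard(2)),),
    in_placements=(
        (Shard(0), Shard(1), Shard(2)),     # query
        (Shard(0), Shard(1), Replicate()),  # key
        (Shard(0), Shard(1), Replicate()),  # value
    ),
    redistribute_inputs=True,
    in_grad_placements=None,
    device_mesh=None,
)
def cp_attention(query, key, value):
    return F.scaled_dot_product_attention(query, key, value, is_causal=False)

def block(q, k, v):
    return cp_attention(q, k, v)

q = k = v = torch.randn(8, 16, 16, 6, requires_grad=True)
# Wrapping is opt-in for now; inside this context Dynamo traces the local_map
# call as a local_map_hop(subgraph, q, k, v) node instead of inlining it.
with LocalMapWrappedHigherOrderVariable.enable():
    out = torch.compile(block, backend="eager")(q, k, v)
    out.sum().backward()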

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
Simon Fan 2025-09-16 14:03:49 -07:00 committed by PyTorch MergeBot
parent c9f16f201a
commit 821458d97a
9 changed files with 701 additions and 3 deletions


@ -7197,6 +7197,7 @@ xfail_hops_compile = {
# aot_eager
"map", # assert type(args[1].realize()) is TensorVariable
"scan", # scan is not an OpOverload
"local_map_hop", # can't retrace
# inductor
"while_loop", # LoweringException: AssertionError
"flex_attention", # LoweringException: AssertionError


@ -0,0 +1,203 @@
# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950
import unittest
import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
if torch.distributed.is_available():
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase
from torch.testing._internal.triton_utils import requires_cuda_and_triton
nested_compile_region = torch.compiler.nested_compile_region
class MyTransform(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x + 100
@staticmethod
def backward(ctx, grad):
return grad + 100
def context_parallel_attention(query, key, value):
out = F.scaled_dot_product_attention(
query=query, key=key, value=value, is_causal=False
)
return out
def create_model(attention_fn, nheads, dim1, dim2):
class LocalMapTransformerBlock(nn.Module):
def __init__(self, nheads, dim1, dim2):
super().__init__()
self.nheads = nheads
bias = False
self.wq = nn.Linear(dim1, dim1, bias=bias)
self.wk = nn.Linear(dim1, dim1, bias=bias)
self.wv = nn.Linear(dim1, dim1, bias=bias)
self.wo = nn.Linear(dim1, dim1, bias=bias)
self.w1 = nn.Linear(dim1, dim2, bias=bias)
self.w2 = nn.Linear(dim2, dim1, bias=bias)
def forward(self, x):
q = self.wq(x)
k = self.wk(x)
v = self.wv(x)
q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
o = attention_fn(q, k, v)
o = o.permute(0, 2, 1, 3).flatten(-2)
o = self.wo(o)
o0 = o + x
o = self.w1(o0)
o = torch.nn.functional.relu(o)
o = self.w2(o)
o = o0 + o
return o
return LocalMapTransformerBlock(nheads, dim1, dim2)
class TestLocalMap(TestCase):
@requires_cuda_and_triton
@unittest.skipIf(
not torch.distributed.is_available(), "Torch distributed not available."
)
def test_simple(self):
@local_map(
out_placements=((Shard(0), Shard(1), Shard(2)),),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
def cp_decorated(query, key, value):
return context_parallel_attention(query, key, value)
cp_function = local_map(
context_parallel_attention,
out_placements=(Shard(0), Shard(1), Shard(2)),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
bs = 8 * 1
dim1 = 96
dim2 = dim1 * 4
nheads = 16
seq_len = 16
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
backend = EagerAndRecordGraphs()
model = create_model(cp_decorated, nheads, dim1, dim2).cuda()
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
model = create_model(cp_function, nheads, dim1, dim2).cuda()
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
if not TEST_WITH_CROSSREF:
self.assertEqual(len(backend.graphs), 2)
# should see local_map_hop in both
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_wq_parameters_weight_: "f32[96, 96]", L_x_: "f32[8, 16, 96]", L_self_modules_wk_parameters_weight_: "f32[96, 96]", L_self_modules_wv_parameters_weight_: "f32[96, 96]", L_self_modules_wo_parameters_weight_: "f32[96, 96]", L_self_modules_w1_parameters_weight_: "f32[384, 96]", L_self_modules_w2_parameters_weight_: "f32[96, 384]"):
l_self_modules_wq_parameters_weight_ = L_self_modules_wq_parameters_weight_
l_x_ = L_x_
l_self_modules_wk_parameters_weight_ = L_self_modules_wk_parameters_weight_
l_self_modules_wv_parameters_weight_ = L_self_modules_wv_parameters_weight_
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
subgraph_0 = self.subgraph_0
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
return (o_6,)
class subgraph_0(torch.nn.Module):
def forward(self, q_1: "f32[8, 16, 16, 6]", k_1: "f32[8, 16, 16, 6]", v_1: "f32[8, 16, 16, 6]"):
out: "f32[8, 16, 16, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
return (out,)
""",
)
self.assertEqual(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
)
if __name__ == "__main__":
run_tests()


@ -5354,7 +5354,7 @@ if torch.distributed.is_available() and HAS_CUDA_AND_TRITON:
test_dtensor.TestDTensorCompile
)
xfail_hops = {}
xfail_hops = {"local_map_hop"}
class TestCompiledAutogradOpInfo(TestCase):


@ -206,7 +206,10 @@ from .functions import (
UserMethodVariable,
WrapperUserFunctionVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .higher_order_ops import (
LocalMapWrappedHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
@ -850,6 +853,8 @@ class VariableBuilder:
return build_checkpoint_variable(source=self.source)
elif is_invoke_subgraph(value):
return build_invoke_subgraph_variable(source=self.source)
elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value):
return LocalMapWrappedHigherOrderVariable.build(source=self.source)
elif isinstance(value, functools.partial):
func_src = AttrSource(self.get_source(), "func")
func_obj = VariableBuilder(self.tx, func_src)(value.func)


@ -3383,6 +3383,7 @@ class BaseHOPVariable(WrapHigherOrderVariable):
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
return _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec
@ -3497,6 +3498,115 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
)
class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
supports_input_mutation = False
supports_aliasing = False
# Subclasses aren't supported by speculate_subgraph yet,
# so this HOP is only usable with plain tensors
_enabled = False
@classmethod
@contextlib.contextmanager
def enable(cls):
"""Context manager to temporarily enable local map wrapping.
Will be removed when speculate_subgraph supports subclass inputs:
https://github.com/pytorch/pytorch/issues/161456.
Usage:
with LocalMapWrappedHigherOrderVariable.enable():
# Code where should_wrap_in_hop will return True
pass
"""
old_value = cls._enabled
cls._enabled = True
try:
yield
finally:
cls._enabled = old_value
@classmethod
def should_wrap_in_hop(cls, value):
if not torch.distributed.is_available():
return False
from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
# the exact type check is important to avoid triggering subclass dispatch via ==
if type(value) != type(_local_map_wrapped):
return False
return value == _local_map_wrapped and cls._enabled
@staticmethod
def build(**options):
return TorchHigherOrderOperatorVariable.make(
torch._higher_order_ops.local_map_hop,
**options,
)
def python_type(self):
return type(self.value)
def _call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
"""
Goal of this function is to rewrite local_map usage as a HOP:
local_map(func, ...) -> local_map_hop(gm, ...)
"""
(
user_func,
out_placements,
in_placements,
in_grad_placements,
device_mesh,
redistribute_inputs,
*user_args,
) = args
(
p_args,
p_kwargs,
example_value,
body_r,
treespec,
body_gmod,
body_name,
) = self.create_wrapped_node(
tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph"
)
# Treat the placements as constants so we don't have to deal with Placement types in fx IR.
# They are guarded with EQUALS_MATCH on the local_map call's arguments.
body_gmod.meta["local_map_kwargs"] = {
"out_placements": out_placements.value,
"in_placements": in_placements.value,
"redistribute_inputs": redistribute_inputs.value,
"in_grad_placements": in_grad_placements.value,
"device_mesh": device_mesh.value,
}
assert len(p_kwargs) == 0
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
out = _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec
)
return out
# Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make()
_hop_name_to_variable_class = {
"cond": CondHigherOrderVariable,
@ -3525,4 +3635,5 @@ _hop_name_to_variable_class = {
"auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable,
"invoke_subgraph": InvokeSubgraphHigherOrderVariable,
"custom_function_call": CustomFunctionHigherOrderOperatorVariable,
"local_map_hop": LocalMapWrappedHigherOrderVariable,
}


@ -21,6 +21,7 @@ from torch._higher_order_ops.flex_attention import (
from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.local_map import local_map_hop
from torch._higher_order_ops.map import map
from torch._higher_order_ops.out_dtype import out_dtype
from torch._higher_order_ops.run_const_graph import run_const_graph
@ -73,4 +74,5 @@ __all__ = [
"aoti_call_delegate",
"map",
"while_loop_stack_output",
"local_map_hop",
]


@ -0,0 +1,327 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# NOTE: this file may be removed once we move to a dynamo frontend
import functools
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Optional
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
clone_outputs_aliasing_inputs,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
# Proxy the HOP instead of inlining into it
_DEFER_INLINING = False
@contextmanager
def defer_inlining() -> Generator[None, None, None]:
global _DEFER_INLINING
prior = _DEFER_INLINING
try:
_DEFER_INLINING = True
yield
finally:
_DEFER_INLINING = prior
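# Hedged usage sketch: without this context manager, the Autograd key below
# just calls fw_gm directly; inside it, create_hop_fw_bw traces and partitions
# a joint graph, and the HOP itself is kept as a node in the traced program.
#
#     with defer_inlining():
#         out = local_map_hop(fw_gm, *args)  # stays a local_map_hop call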
class LocalMapHOP(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("local_map_hop")
def __call__(self, fw_gm: GraphModule, *args: Any, **kwargs: Any) -> Any:
return super().__call__(fw_gm, *args, **kwargs)
local_map_hop = LocalMapHOP()
def create_hop_fw_bw(
fw_gm: GraphModule,
*_args: Any,
) -> tuple[GraphModule, GraphModule, int, int, set[int]]:
"""
Traces a joint, applies passes and partitions it
"""
# Keep these imports local to avoid circular dependencies
# once we upstream with the dynamo frontend
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# create a fake tensor from a compiler-wrapped FunctionalTensor
def _from_fun(t: Any) -> Any:
if isinstance(t, torch.Tensor):
return torch.empty_strided(
t.size(),
t.stride(),
device=t.device,
dtype=t.dtype,
requires_grad=t.requires_grad,
)
return t
# If someone runs this hop under the default compiler backend ("eager"),
# then this path will be run with the actual user inputs. We convert them
# to fake tensors in order to avoid performing any actual compute.
fake_mode = detect_fake_mode(_args)
if fake_mode is None:
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
with fake_mode:
fw_inputs = pytree.tree_map(_from_fun, _args)
assert all(
isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs
), f"Unexpected element in {fw_inputs=}"
example_grads = pytree.tree_map(
_from_fun,
fw_gm(*fw_inputs),
)
if not isinstance(example_grads, (list, tuple)):
example_grads = [example_grads]
num_fw_inputs = len(fw_inputs)
num_fw_outputs = len(example_grads)
def joint_f(
*primals_and_tangents: list[torch.Tensor],
) -> Any:
primals = primals_and_tangents[:num_fw_inputs]
tangents = primals_and_tangents[num_fw_inputs:]
def prepare_fw_with_masks(fn: Callable[..., Any]) -> Callable[..., Any]:
def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]:
fw_out = fn(*args)
assert isinstance(fw_out, tuple), (
"Dynamo traced submodule should return tuple"
)
return fw_out, [
True
if isinstance(ret, torch.Tensor) and ret.requires_grad
else False
for ret in fw_out
]
return fw_with_masks
fw_outs, grads = create_joint(
prepare_fw_with_masks(fw_gm), aot_config=dummy_aot_config
)(primals, tangents)
maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
# put grads first to work with existing hop utils
return pytree.tree_map(maybe_clone, (*grads, *fw_outs))
filtered_grads_idx = set()
for i, example_grad in enumerate(example_grads):
# Filter out grads that are None or do not require_grad.
# The AOTAutograd utils we rely on force this assumption.
# We must also filter the runtime tangents.
if example_grad is not None and (
isinstance(example_grad, torch.Tensor) and example_grad.requires_grad
):
filtered_grads_idx.add(i)
primals_and_tangents = [
*fw_inputs,
*[example_grads[i] for i in filtered_grads_idx],
]
joint_hop_gm = make_fx(joint_f)(*primals_and_tangents)
from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner
from torch._inductor.compile_fx import partition_fn
# Match partitioner convention
prepped_joint_hop_gm = prepare_for_partitioner(
joint_hop_gm, num_fw_inputs, num_fw_outputs
)
# Also runs joint passes
new_fw_gm, new_bw_gm = partition_fn(
prepped_joint_hop_gm,
[],
num_fwd_outputs=num_fw_outputs,
static_lifetime_input_indices=[],
)
# Propagate meta onto the fw/bw graphs; it will later be set on the proxied nodes
local_map_kwargs = fw_gm.meta["local_map_kwargs"] # type: ignore[attr-defined]
new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs
new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs}
# Okay because Autoparallel assumes the same sharding for params and their grads
new_bw_gm.meta["local_map_kwargs"]["in_placements"] = local_map_kwargs[
"out_placements"
]
new_bw_gm.meta["local_map_kwargs"]["out_placements"] = local_map_kwargs[
"in_placements"
]
return new_fw_gm, new_bw_gm, num_fw_inputs, num_fw_outputs, filtered_grads_idx
class LocalMapAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
fw_gm: GraphModule,
bw_gm: GraphModule,
num_fw_ins: int,
num_fw_outs: int,
filtered_grads_idx: set[int],
*args: Any,
**kwargs: Any,
) -> tuple[Optional[torch.Tensor], ...]:
ctx.bw_gm = bw_gm
ctx.num_fw_ins = num_fw_ins
ctx.filtered_grads_idx = filtered_grads_idx
with torch._C._AutoDispatchBelowAutograd():
fw_outs_with_saved_activations = local_map_hop(fw_gm, *args, **kwargs)
fw_outs = fw_outs_with_saved_activations[:num_fw_outs]
saved_activations = fw_outs_with_saved_activations[num_fw_outs:]
save_tensors_and_symints_for_backward(ctx, saved_activations)
return fw_outs
@staticmethod
def backward(
ctx: Any, *_grads: tuple[torch.Tensor]
) -> tuple[Optional[torch.Tensor], ...]:
saved_activations = saved_tensors_and_symints(ctx)
with torch._C._AutoDispatchBelowAutograd():
# Filter out grads that are None or do not require_grad.
# The AOTAutograd utils we rely on force this assumption.
grads = [_grads[i] for i in ctx.filtered_grads_idx]
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
if len(grad_ins) != ctx.num_fw_ins:
raise RuntimeError(
f"Expected {ctx.num_fw_ins} grad_ins, got {len(grad_ins)}"
)
return None, None, None, None, None, *grad_ins
@local_map_hop.py_impl(torch._C.DispatchKey.Autograd)
def autograd_key(
fw_gm: GraphModule,
*args: Any,
**kwargs: Any,
) -> Any:
if _DEFER_INLINING:
fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx = create_hop_fw_bw(
fw_gm, *args
)
return LocalMapAutogradOp.apply(
fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx, *args, **kwargs
)
return fw_gm(*args, **kwargs)
@local_map_hop.py_functionalize_impl
def functional_mode_key(
ctx: Any, fw_gm: GraphModule, *args: Any, **kwargs: Any
) -> tuple[torch.Tensor]:
assert not kwargs
unwrapped_inputs = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
out = local_map_hop(fw_gm, *unwrapped_inputs)
return ctx.wrap_tensors(out)
@local_map_hop.py_impl(FakeTensorMode)
def fake_mode_key(
mode: FakeTensorMode,
fw_gm: GraphModule,
*args: Any,
**kwargs: Any,
) -> tuple[torch.Tensor]:
with mode:
return fw_gm(*args, **kwargs)
def proxy_mode_key_common(
call_hop: Callable[..., Any],
proxy_mode: ProxyTorchDispatchMode,
gm: GraphModule,
*args: Any,
**kwargs: Any,
) -> tuple[torch.Tensor]:
assert proxy_mode is not None, (
"Mode should always be enabled for python fallback key"
)
assert len(kwargs) == 0
example_out = call_hop(*args, **kwargs)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) # type: ignore[union-attr]
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", call_hop, proxy_args, {}
)
# extract local_map args, post-dispatch operates on GraphModules
assert gm.meta["local_map_kwargs"]
local_map_kwargs = gm.meta["local_map_kwargs"]
# propagate local_map args to the call_function node
out_proxy.node.meta["local_map_kwargs"] = local_map_kwargs
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
@local_map_hop.py_impl(ProxyTorchDispatchMode)
def proxy_mode_key(
proxy_mode: ProxyTorchDispatchMode,
fw_gm: GraphModule,
*args: Any,
**kwargs: Any,
) -> tuple[torch.Tensor]:
# TODO: get rid of this when we can install as a subgraph
def call_local_map(*_args: Any, **_kwargs: Any) -> Any:
return functools.partial(local_map_hop, fw_gm)(*_args, **_kwargs)
return proxy_mode_key_common(call_local_map, proxy_mode, fw_gm, *args, **kwargs)
# Running HOP in eager with real tensors
@local_map_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
def real_impl(
fw_gm: GraphModule,
*args: Any,
**kwargs: Any,
) -> tuple[torch.Tensor]:
return fw_gm(*args, **kwargs)


@ -490,7 +490,9 @@ def propagate_shape_and_sharding(
- An output dimension that is a split of the input dimension can only be sharded
if the leftmost split size is divisible by the mesh dimension
"""
assert len(input_src_placements) == len(mesh_sizes)
assert len(input_src_placements) == len(mesh_sizes), (
f"{input_src_placements} != {mesh_sizes}"
)
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
mesh_ndim = len(mesh_sizes)
shardable_dims: dict[int, list[bool]] = {}


@ -212,6 +212,32 @@ def simple_while_loop_stack_output(iter_t, x):
return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())
def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
# TODO: once HOPs support DTensor inputs, we should also test DTensors
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=False
)
yield SampleInput(
make_arg(2, 3, 4, low=0.1, high=2),
make_arg(2, 3, 4, low=0.1, high=2),
)
def simple_local_map_hop(inp1, inp2):
def body_gm(inp1, inp2):
return inp1.cos() + inp2.sin()
gm = torch.fx.symbolic_trace(body_gm)
assert torch.distributed.is_available()
from torch.distributed.tensor.placement_types import Replicate
gm.meta["local_map_kwargs"] = {
"in_placements": (Replicate(), Replicate(), Replicate()),
"out_placements": ((Replicate(), Replicate(), Replicate()),)
}
# TODO: Dynamo would rewrite this op differently
return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)
def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@ -451,4 +477,25 @@ hop_db = [
),
decorators=[onlyCUDA],
),
OpInfo(
name="local_map_hop",
variant_test_name="simple",
op=simple_local_map_hop,
sample_inputs_func=sample_inputs_local_map_hop,
dtypes=custom_types(torch.float16, torch.float32),
supports_out=False,
check_batched_grad=False,
check_batched_gradgrad=False,
check_batched_forward_grad=False,
check_inplace_batched_forward_grad=False,
skips=(
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
DecorateInfo(
unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
),
decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")],
),
]