[dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4

This commit is contained in:
parent c9f16f201a
commit 821458d97a
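
For orientation, here is a minimal sketch of the flow this PR enables, condensed from the new test/higher_order_ops/test_local_map.py below. The toy attention function, the CPU tensors, the "eager" backend, and the shapes are illustrative choices (the real test runs on CUDA with Triton); the LocalMapWrappedHigherOrderVariable.enable() gate is the temporary toggle added by this PR, needed until speculate_subgraph supports subclass inputs (issue 161456), and a distributed build is assumed.

import torch
import torch.nn.functional as F
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard


@local_map(
    out_placements=((Shard(0), Shard(1), Shard(2)),),
    in_placements=(
        (Shard(0), Shard(1), Shard(2)),  # query
        (Shard(0), Shard(1), Replicate()),  # key
        (Shard(0), Shard(1), Replicate()),  # value
    ),
    redistribute_inputs=True,
    in_grad_placements=None,
    device_mesh=None,
)
def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)


def block(q, k, v):
    # attn is resolved through VariableBuilder; with the gate enabled it is
    # captured as local_map_hop instead of inlining the DTensor machinery.
    return attn(q, k, v).sum()


q, k, v = (torch.randn(8, 16, 16, 6, requires_grad=True) for _ in range(3))
with LocalMapWrappedHigherOrderVariable.enable():
    out = torch.compile(block, backend="eager")(q, k, v)
out.backward()
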
@ -7197,6 +7197,7 @@ xfail_hops_compile = {
    # aot_eager
    "map",  # assert type(args[1].realize()) is TensorVariable
    "scan",  # scan is not an OpOverload
    "local_map_hop",  # can't retrace
    # inductor
    "while_loop",  # LoweringException: AssertionError
    "flex_attention",  # LoweringException: AssertionError

test/higher_order_ops/test_local_map.py (new file, 203 lines)
@ -0,0 +1,203 @@
# Owner(s): ["module: higher order operators"]
# flake8: noqa: B950


import unittest

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable


if torch.distributed.is_available():
    from torch.distributed._tensor.experimental import local_map
    from torch.distributed.tensor.placement_types import Replicate, Shard

from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase
from torch.testing._internal.triton_utils import requires_cuda_and_triton


nested_compile_region = torch.compiler.nested_compile_region


class MyTransform(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x + 100

    @staticmethod
    def backward(ctx, grad):
        return grad + 100


def context_parallel_attention(query, key, value):
    out = F.scaled_dot_product_attention(
        query=query, key=key, value=value, is_causal=False
    )
    return out


def create_model(attention_fn, nheads, dim1, dim2):
    class LocalMapTransformerBlock(nn.Module):
        def __init__(self, nheads, dim1, dim2):
            super().__init__()
            self.nheads = nheads
            bias = False
            self.wq = nn.Linear(dim1, dim1, bias=bias)
            self.wk = nn.Linear(dim1, dim1, bias=bias)
            self.wv = nn.Linear(dim1, dim1, bias=bias)
            self.wo = nn.Linear(dim1, dim1, bias=bias)
            self.w1 = nn.Linear(dim1, dim2, bias=bias)
            self.w2 = nn.Linear(dim2, dim1, bias=bias)

        def forward(self, x):
            q = self.wq(x)
            k = self.wk(x)
            v = self.wv(x)

            q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
            k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
            v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)

            o = attention_fn(q, k, v)
            o = o.permute(0, 2, 1, 3).flatten(-2)

            o = self.wo(o)

            o0 = o + x

            o = self.w1(o0)
            o = torch.nn.functional.relu(o)
            o = self.w2(o)

            o = o0 + o
            return o

    return LocalMapTransformerBlock(nheads, dim1, dim2)


class TestLocalMap(TestCase):
    @requires_cuda_and_triton
    @unittest.skipIf(
        not torch.distributed.is_available(), "Torch distributed not available."
    )
    def test_simple(self):
        @local_map(
            out_placements=((Shard(0), Shard(1), Shard(2)),),
            in_placements=(
                (Shard(0), Shard(1), Shard(2)),  # query
                (Shard(0), Shard(1), Replicate()),  # key
                (Shard(0), Shard(1), Replicate()),  # value
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=None,
        )
        def cp_decorated(query, key, value):
            return context_parallel_attention(query, key, value)

        cp_function = local_map(
            context_parallel_attention,
            out_placements=(Shard(0), Shard(1), Shard(2)),
            in_placements=(
                (Shard(0), Shard(1), Shard(2)),  # query
                (Shard(0), Shard(1), Replicate()),  # key
                (Shard(0), Shard(1), Replicate()),  # value
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=None,
        )
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        nheads = 16
        seq_len = 16

        from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm

        backend = EagerAndRecordGraphs()

        model = create_model(cp_decorated, nheads, dim1, dim2).cuda()
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        model = create_model(cp_function, nheads, dim1, dim2).cuda()
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            # should see local_map_hop in both
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_wq_parameters_weight_: "f32[96, 96]", L_x_: "f32[8, 16, 96]", L_self_modules_wk_parameters_weight_: "f32[96, 96]", L_self_modules_wv_parameters_weight_: "f32[96, 96]", L_self_modules_wo_parameters_weight_: "f32[96, 96]", L_self_modules_w1_parameters_weight_: "f32[384, 96]", L_self_modules_w2_parameters_weight_: "f32[96, 384]"):
        l_self_modules_wq_parameters_weight_ = L_self_modules_wq_parameters_weight_
        l_x_ = L_x_
        l_self_modules_wk_parameters_weight_ = L_self_modules_wk_parameters_weight_
        l_self_modules_wv_parameters_weight_ = L_self_modules_wv_parameters_weight_
        l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
        l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
        l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_

        q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None

        k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None

        v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None

        unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
        q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None

        unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
        k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None

        unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
        v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None

        subgraph_0 = self.subgraph_0
        local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
        o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None

        permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
        o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None

        o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None

        o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None

        o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None

        o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None

        o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None

        o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
        return (o_6,)

    class subgraph_0(torch.nn.Module):
        def forward(self, q_1: "f32[8, 16, 16, 6]", k_1: "f32[8, 16, 16, 6]", v_1: "f32[8, 16, 16, 6]"):
            out: "f32[8, 16, 16, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
            return (out,)
""",
            )

            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )


if __name__ == "__main__":
    run_tests()
@ -5354,7 +5354,7 @@ if torch.distributed.is_available() and HAS_CUDA_AND_TRITON:
        test_dtensor.TestDTensorCompile
    )

-xfail_hops = {}
+xfail_hops = {"local_map_hop"}


class TestCompiledAutogradOpInfo(TestCase):
@ -206,7 +206,10 @@ from .functions import (
    UserMethodVariable,
    WrapperUserFunctionVariable,
)
-from .higher_order_ops import TorchHigherOrderOperatorVariable
+from .higher_order_ops import (
+    LocalMapWrappedHigherOrderVariable,
+    TorchHigherOrderOperatorVariable,
+)
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
@ -850,6 +853,8 @@ class VariableBuilder:
            return build_checkpoint_variable(source=self.source)
        elif is_invoke_subgraph(value):
            return build_invoke_subgraph_variable(source=self.source)
        elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value):
            return LocalMapWrappedHigherOrderVariable.build(source=self.source)
        elif isinstance(value, functools.partial):
            func_src = AttrSource(self.get_source(), "func")
            func_obj = VariableBuilder(self.tx, func_src)(value.func)
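
A small illustration of the gate these two new branches rely on, using the _local_map_wrapped sentinel that should_wrap_in_hop itself imports (a sketch, assuming a distributed build; outside the context manager the builder leaves local_map alone and Dynamo traces it as an ordinary function):

import torch
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch.distributed.tensor.experimental._func_map import _local_map_wrapped

assert torch.distributed.is_available()
# Disabled by default: local_map is not rewritten.
assert not LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(_local_map_wrapped)
# Inside the temporary gate, the builder routes it to the local_map_hop HOP.
with LocalMapWrappedHigherOrderVariable.enable():
    assert LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(_local_map_wrapped)
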
@ -3383,6 +3383,7 @@ class BaseHOPVariable(WrapHigherOrderVariable):
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
        return _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
@ -3497,6 +3498,115 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
        )


class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
    supports_input_mutation = False
    supports_aliasing = False

    # Subclasses aren't supported by speculate_subgraph yet,
    # so this HOP is only usable with plain tensors.
    _enabled = False

    @classmethod
    @contextlib.contextmanager
    def enable(cls):
        """Context manager to temporarily enable local map wrapping.

        Will be removed when speculate_subgraph supports subclass inputs:
        https://github.com/pytorch/pytorch/issues/161456.

        Usage:
            with LocalMapWrappedHigherOrderVariable.enable():
                # Code where should_wrap_in_hop will return True
                pass
        """
        old_value = cls._enabled
        cls._enabled = True
        try:
            yield
        finally:
            cls._enabled = old_value

    @classmethod
    def should_wrap_in_hop(cls, value):
        if not torch.distributed.is_available():
            return False

        from torch.distributed.tensor.experimental._func_map import _local_map_wrapped

        # The type check is important to avoid subclass dispatch.
        if type(value) != type(_local_map_wrapped):
            return False

        return value == _local_map_wrapped and cls._enabled

    @staticmethod
    def build(**options):
        return TorchHigherOrderOperatorVariable.make(
            torch._higher_order_ops.local_map_hop,
            **options,
        )

    def python_type(self):
        return type(self.value)

    def _call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        The goal of this function is to rewrite a local_map usage as a HOP:
            local_map(func, ...) -> local_map_hop(gm, ...)
        """

        (
            user_func,
            out_placements,
            in_placements,
            in_grad_placements,
            device_mesh,
            redistribute_inputs,
            *user_args,
        ) = args

        (
            p_args,
            p_kwargs,
            example_value,
            body_r,
            treespec,
            body_gmod,
            body_name,
        ) = self.create_wrapped_node(
            tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph"
        )

        # Treat as const, so we don't have to deal with Placement types in fx IR.
        # Guarded with EQUALS_MATCH on the local_map call's arguments.
        body_gmod.meta["local_map_kwargs"] = {
            "out_placements": out_placements.value,
            "in_placements": in_placements.value,
            "redistribute_inputs": redistribute_inputs.value,
            "in_grad_placements": in_grad_placements.value,
            "device_mesh": device_mesh.value,
        }

        assert len(p_kwargs) == 0

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
        out = _call_function_and_unflatten_output(
            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
        )

        return out


# Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make()
_hop_name_to_variable_class = {
    "cond": CondHigherOrderVariable,
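
To make the _call_function rewrite concrete: with the gate enabled, the Dynamo-captured graph calls the HOP on an installed submodule and the placements never become fx values. A sketch of inspecting this, assuming the EagerAndRecordGraphs backend from the new test and that the installed subgraph_0 submodule is the body_gmod carrying the meta (the names follow the expected output in test_local_map.py):

gm = backend.graphs[0]  # recorded by the EagerAndRecordGraphs test backend
print(gm.code)
# ...
#     subgraph_0 = self.subgraph_0
#     local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1)
# ...
# Placement/device_mesh arguments ride along as constants on the subgraph's meta
# and are guarded via EQUALS_MATCH on the original local_map call's arguments.
print(gm.subgraph_0.meta["local_map_kwargs"])
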
@ -3525,4 +3635,5 @@ _hop_name_to_variable_class = {
    "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable,
    "invoke_subgraph": InvokeSubgraphHigherOrderVariable,
    "custom_function_call": CustomFunctionHigherOrderOperatorVariable,
    "local_map_hop": LocalMapWrappedHigherOrderVariable,
}
@ -21,6 +21,7 @@ from torch._higher_order_ops.flex_attention import (
from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.local_map import local_map_hop
from torch._higher_order_ops.map import map
from torch._higher_order_ops.out_dtype import out_dtype
from torch._higher_order_ops.run_const_graph import run_const_graph
@ -73,4 +74,5 @@ __all__ = [
    "aoti_call_delegate",
    "map",
    "while_loop_stack_output",
    "local_map_hop",
]
torch/_higher_order_ops/local_map.py (new file, 327 lines)
@ -0,0 +1,327 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: this file may be removed once we move to a dynamo frontend

import functools
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Callable, Optional

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    clone_outputs_aliasing_inputs,
    save_tensors_and_symints_for_backward,
    saved_tensors_and_symints,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree


# Proxy the HOP instead of inlining into it
_DEFER_INLINING = False


@contextmanager
def defer_inlining() -> Generator[None, None, None]:
    global _DEFER_INLINING
    prior = _DEFER_INLINING
    try:
        _DEFER_INLINING = True
        yield
    finally:
        _DEFER_INLINING = prior


class LocalMapHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("local_map_hop")

    def __call__(self, fw_gm: GraphModule, *args: Any, **kwargs: Any) -> Any:
        return super().__call__(fw_gm, *args, **kwargs)


local_map_hop = LocalMapHOP()


def create_hop_fw_bw(
    fw_gm: GraphModule,
    *_args: Any,
) -> tuple[GraphModule, GraphModule, int, int, set[int]]:
    """
    Traces a joint graph, applies joint passes, and partitions it into
    forward and backward graphs.
    """
    # Keeping these imports here to avoid circular dependencies
    # once we upstream with the dynamo frontend.
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._guards import detect_fake_mode
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing, make_fx

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # create a tensor (fake) from a compiler wrapped FunctionalTensor
            def _from_fun(t: Any) -> Any:
                if isinstance(t, torch.Tensor):
                    return torch.empty_strided(
                        t.size(),
                        t.stride(),
                        device=t.device,
                        dtype=t.dtype,
                        requires_grad=t.requires_grad,
                    )
                return t

            # If someone runs this hop under the default compiler backend ("eager"),
            # then this path will be run with the actual user inputs. We convert them
            # to fake tensors in order to not perform any actual compute.

            fake_mode = detect_fake_mode(_args)
            if fake_mode is None:
                fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

            with fake_mode:
                fw_inputs = pytree.tree_map(_from_fun, _args)

                assert all(
                    isinstance(t, (FakeTensor, int, torch.SymInt)) for t in fw_inputs
                ), f"Unexpected element in {fw_inputs=}"

                example_grads = pytree.tree_map(
                    _from_fun,
                    fw_gm(*fw_inputs),
                )
                if not isinstance(example_grads, (list, tuple)):
                    example_grads = [example_grads]

            num_fw_inputs = len(fw_inputs)
            num_fw_outputs = len(example_grads)

            def joint_f(
                *primals_and_tangents: list[torch.Tensor],
            ) -> Any:
                primals = primals_and_tangents[:num_fw_inputs]
                tangents = primals_and_tangents[num_fw_inputs:]

                def prepare_fw_with_masks(fn: Callable[..., Any]) -> Callable[..., Any]:
                    def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]:
                        fw_out = fn(*args)
                        assert isinstance(fw_out, tuple), (
                            "Dynamo traced submodule should return tuple"
                        )
                        return fw_out, [
                            True
                            if isinstance(ret, torch.Tensor) and ret.requires_grad
                            else False
                            for ret in fw_out
                        ]

                    return fw_with_masks

                fw_outs, grads = create_joint(
                    prepare_fw_with_masks(fw_gm), aot_config=dummy_aot_config
                )(primals, tangents)

                maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
                # put grads first to work with existing hop utils
                return pytree.tree_map(maybe_clone, (*grads, *fw_outs))

            filtered_grads_idx = set()
            for i, example_grad in enumerate(example_grads):
                # Filter out grads that are None or do not require_grad.
                # The AOTAutograd utils we rely on force this assumption.
                # We must also filter the runtime tangents too.
                if example_grad is not None and (
                    isinstance(example_grad, torch.Tensor) and example_grad.requires_grad
                ):
                    filtered_grads_idx.add(i)

            primals_and_tangents = [
                *fw_inputs,
                *[example_grads[i] for i in filtered_grads_idx],
            ]
            joint_hop_gm = make_fx(joint_f)(*primals_and_tangents)

            from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner
            from torch._inductor.compile_fx import partition_fn

            # Match partitioner convention
            prepped_joint_hop_gm = prepare_for_partitioner(
                joint_hop_gm, num_fw_inputs, num_fw_outputs
            )
            # Also runs joint passes
            new_fw_gm, new_bw_gm = partition_fn(
                prepped_joint_hop_gm,
                [],
                num_fwd_outputs=num_fw_outputs,
                static_lifetime_input_indices=[],
            )

            # Propagate meta onto fw/bw graphs, later will be set on proxied nodes
            local_map_kwargs = fw_gm.meta["local_map_kwargs"]  # type: ignore[attr-defined]

            new_fw_gm.meta["local_map_kwargs"] = local_map_kwargs
            new_bw_gm.meta["local_map_kwargs"] = {**local_map_kwargs}
            # Okay because Autoparallel assumes same sharding between param and grads
            new_bw_gm.meta["local_map_kwargs"]["in_placements"] = local_map_kwargs[
                "out_placements"
            ]
            new_bw_gm.meta["local_map_kwargs"]["out_placements"] = local_map_kwargs[
                "in_placements"
            ]

            return new_fw_gm, new_bw_gm, num_fw_inputs, num_fw_outputs, filtered_grads_idx


class LocalMapAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        fw_gm: GraphModule,
        bw_gm: GraphModule,
        num_fw_ins: int,
        num_fw_outs: int,
        filtered_grads_idx: set[int],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Optional[torch.Tensor], ...]:
        ctx.bw_gm = bw_gm
        ctx.num_fw_ins = num_fw_ins
        ctx.filtered_grads_idx = filtered_grads_idx

        with torch._C._AutoDispatchBelowAutograd():
            fw_outs_with_saved_activations = local_map_hop(fw_gm, *args, **kwargs)

        fw_outs = fw_outs_with_saved_activations[:num_fw_outs]
        saved_activations = fw_outs_with_saved_activations[num_fw_outs:]
        save_tensors_and_symints_for_backward(ctx, saved_activations)

        return fw_outs

    @staticmethod
    def backward(
        ctx: Any, *_grads: tuple[torch.Tensor]
    ) -> tuple[Optional[torch.Tensor], ...]:
        saved_activations = saved_tensors_and_symints(ctx)
        with torch._C._AutoDispatchBelowAutograd():
            # Filter out grads that are None or do not require_grad.
            # The AOTAutograd utils we rely on force this assumption.
            grads = [_grads[i] for i in ctx.filtered_grads_idx]
            grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
            if len(grad_ins) != ctx.num_fw_ins:
                raise RuntimeError(
                    f"Expected {ctx.num_fw_ins} grad_ins, got {len(grad_ins)}"
                )
        return None, None, None, None, None, *grad_ins


@local_map_hop.py_impl(torch._C.DispatchKey.Autograd)
def autograd_key(
    fw_gm: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> Any:
    if _DEFER_INLINING:
        fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx = create_hop_fw_bw(
            fw_gm, *args
        )
        return LocalMapAutogradOp.apply(
            fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx, *args, **kwargs
        )

    return fw_gm(*args, **kwargs)


@local_map_hop.py_functionalize_impl
def functional_mode_key(
    ctx: Any, fw_gm: GraphModule, *args: Any, **kwargs: Any
) -> tuple[torch.Tensor]:
    assert not kwargs

    unwrapped_inputs = ctx.unwrap_tensors(args)
    with ctx.redispatch_to_next():
        out = local_map_hop(fw_gm, *unwrapped_inputs)
        return ctx.wrap_tensors(out)


@local_map_hop.py_impl(FakeTensorMode)
def fake_mode_key(
    mode: FakeTensorMode,
    fw_gm: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    with mode:
        return fw_gm(*args, **kwargs)


def proxy_mode_key_common(
    call_hop: Callable[..., Any],
    proxy_mode: ProxyTorchDispatchMode,
    gm: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    assert proxy_mode is not None, (
        "Mode should always be enabled for python fallback key"
    )
    assert len(kwargs) == 0

    example_out = call_hop(*args, **kwargs)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)  # type: ignore[union-attr]

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", call_hop, proxy_args, {}
    )

    # extract local_map args, post-dispatch operates on GraphModules
    assert gm.meta["local_map_kwargs"]
    local_map_kwargs = gm.meta["local_map_kwargs"]

    # propagate local_map args to the call_function node
    out_proxy.node.meta["local_map_kwargs"] = local_map_kwargs
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@local_map_hop.py_impl(ProxyTorchDispatchMode)
def proxy_mode_key(
    proxy_mode: ProxyTorchDispatchMode,
    fw_gm: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    # TODO: get rid of this when we can install as a subgraph
    def call_local_map(*_args: Any, **_kwargs: Any) -> Any:
        return functools.partial(local_map_hop, fw_gm)(*_args, **_kwargs)

    return proxy_mode_key_common(call_local_map, proxy_mode, fw_gm, *args, **kwargs)


# Running HOP in eager with real tensors
@local_map_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
def real_impl(
    fw_gm: GraphModule,
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    return fw_gm(*args, **kwargs)
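
A sketch of how these dispatcher entries behave outside of Dynamo, modeled on simple_local_map_hop in the OpInfo changes further down; the toy body, shapes, and Replicate placements are illustrative, and the deferred branch is only lightly sketched here (it is the path Autoparallel is expected to drive). By default autograd_key inlines the forward GraphModule; under defer_inlining() it builds a joint graph, partitions it via create_hop_fw_bw, and routes execution through LocalMapAutogradOp.

import torch
from torch._higher_order_ops.local_map import defer_inlining, local_map_hop
from torch.distributed.tensor.placement_types import Replicate


def body(x, y):
    # Dynamo-traced submodules return tuples; create_hop_fw_bw asserts this.
    return (x.cos() + y.sin(),)


gm = torch.fx.symbolic_trace(body)
# Post-dispatch consumers read the placements from the GraphModule's meta,
# mirroring simple_local_map_hop below; nothing validates them on this path.
gm.meta["local_map_kwargs"] = {
    "in_placements": (Replicate(), Replicate()),
    "out_placements": ((Replicate(), Replicate()),),
}

x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(2, 3, requires_grad=True)

# Default autograd path: the HOP simply runs the forward GraphModule inline.
(out,) = local_map_hop(gm, x, y)
out.sum().backward()

# Deferred path: trace a joint, partition into fw/bw graphs, and dispatch
# through LocalMapAutogradOp instead of inlining.
with defer_inlining():
    (out,) = local_map_hop(gm, x, y)
out.sum().backward()
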
@ -490,7 +490,9 @@ def propagate_shape_and_sharding(
    - An output dimension that is a split of the input dimension can only be sharded
      if the leftmost split size is divisible by the mesh dimension
    """
-    assert len(input_src_placements) == len(mesh_sizes)
+    assert len(input_src_placements) == len(mesh_sizes), (
+        f"{input_src_placements} != {mesh_sizes}"
+    )
    # for each input dim, for each mesh dim, provides a list of possible shardable dimensions
    mesh_ndim = len(mesh_sizes)
    shardable_dims: dict[int, list[bool]] = {}
@ -212,6 +212,32 @@ def simple_while_loop_stack_output(iter_t, x):
    return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())


def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
    # TODO: once HOPs support DTensor inputs, we should also test DTensors
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        make_arg(2, 3, 4, low=0.1, high=2),
        make_arg(2, 3, 4, low=0.1, high=2),
    )


def simple_local_map_hop(inp1, inp2):
    def body_gm(inp1, inp2):
        return inp1.cos() + inp2.sin()

    gm = torch.fx.symbolic_trace(body_gm)

    assert torch.distributed.is_available()
    from torch.distributed.tensor.placement_types import Replicate

    gm.meta["local_map_kwargs"] = {
        "in_placements": (Replicate(), Replicate(), Replicate()),
        "out_placements": ((Replicate(), Replicate(), Replicate()),)
    }

    # TODO: Dynamo would rewrite this op differently
    return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)


def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@ -451,4 +477,25 @@ hop_db = [
        ),
        decorators=[onlyCUDA],
    ),
    OpInfo(
        name="local_map_hop",
        variant_test_name="simple",
        op=simple_local_map_hop,
        sample_inputs_func=sample_inputs_local_map_hop,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(
                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
            ),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")],
    ),
]