Add type annotations for higher order ops/flex_attention (#137065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137065
Approved by: https://github.com/drisspg, https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049
This commit is contained in:
parent 3b8511dadf
commit 2854d157de
@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result


@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result


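The two hunks above make the fake-mode implementations forward the defaulted parameter explicitly instead of leaving it implicit. The sketch below is illustrative only; dense_impl and fake_impl are made-up names, not PyTorch APIs. It shows the general pattern of spelling out a defaulted keyword ahead of **kwargs when a wrapper forwards arbitrary keyword arguments to a typed callee.

# Illustrative sketch only -- dense_impl/fake_impl are hypothetical, not PyTorch APIs.
from typing import Any, Dict, Optional, Tuple


def dense_impl(
    op: str,
    _only_clone_these: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # Stand-in for a dense implementation: a defaulted keyword ahead of **kwargs.
    return {"op": op, "clone": _only_clone_these, **kwargs}


def fake_impl(op: str, **kwargs: Any) -> Dict[str, Any]:
    # Mirrors the hunks above: pass the default explicitly when forwarding **kwargs.
    return dense_impl(op, _only_clone_these=None, **kwargs)


print(fake_impl("my_op", x=1))  # {'op': 'my_op', 'clone': None, 'x': 1}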
@@ -1,10 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
@@ -12,7 +11,7 @@ from torch._higher_order_ops.utils import (
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
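For context on the signature being annotated above: a `TorchFunctionMode` handler receives the intercepted callable, the involved types, the positional arguments, and an optional kwargs dict. Below is a minimal, self-contained mode (not code from this PR; `LogGetItem` is a made-up name) with the same parameter shape.

# Minimal illustrative TorchFunctionMode (not from this PR); LogGetItem is hypothetical.
from typing import Dict, Optional, Tuple

import torch
from torch.overrides import TorchFunctionMode


class LogGetItem(TorchFunctionMode):
    def __torch_function__(
        self,
        func,
        types: Tuple[type, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        kwargs = kwargs or {}
        if func is torch.Tensor.__getitem__:
            print("intercepted Tensor.__getitem__")
        # The mode is paused while this handler runs, so calling func directly is safe.
        return func(*args, **kwargs)


with LogGetItem():
    t = torch.arange(4)
    _ = t[torch.tensor(1)]  # prints "intercepted Tensor.__getitem__"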
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(


 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]

     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():

-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
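`_from_fun` manufactures a fresh tensor that shares its input's metadata, and the new annotation makes that Tensor-to-Tensor contract explicit. The helper below is a plausible standalone sketch written from the surrounding context, not a verbatim copy; `from_fun_like` is a hypothetical name.

# Illustrative helper in the spirit of _from_fun: an empty tensor sharing the input's metadata.
import torch
from torch import Tensor


def from_fun_like(t: Tensor) -> Tensor:
    # empty_strided allocates without copying data; only metadata is reproduced.
    return torch.empty_strided(
        t.size(),
        t.stride(),
        dtype=t.dtype,
        device=t.device,
        requires_grad=t.requires_grad,
    )


src = torch.arange(6.0).reshape(2, 3).t()  # non-contiguous source
dst = from_fun_like(src)
print(dst.shape, dst.stride())  # torch.Size([3, 2]) (1, 3)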
@@ -544,8 +553,18 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
             )
             example_grad = _from_fun(example_flat_out)

-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return out, logsumexp

     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
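The `# type: ignore[override]` carried in the hunk above is the usual price of annotating `backward`: the base `torch.autograd.Function.backward` is declared as `(ctx, *grad_outputs)`, so narrowing it to named, typed parameters looks like an incompatible override to mypy. A small self-contained sketch of the same pattern (`Square` is a made-up example, not from this PR):

# Illustrative custom autograd.Function with annotations in the same style; not from this PR.
from typing import Any, Optional, Tuple

import torch
from torch import Tensor


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
        (x,) = ctx.saved_tensors
        return (2 * x * grad_out,)


x = torch.ones(3, requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])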
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
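Switching `example_vals` from a concatenated list to a literal tuple lines up with the new `index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]` parameter of `create_fw_bw_graph`: a fixed-arity tuple annotation lets the checker see both the element count and each element's type, which a list built with `+` cannot express. A standalone sketch of the same idea (variable names are illustrative):

# Illustrative only: a fixed-arity tuple of scalar example tensors, typed per element.
from typing import Tuple

import torch
from torch import Tensor

example_vals: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] = (
    torch.zeros((), dtype=torch.float32, requires_grad=True),  # score placeholder
    torch.zeros((), dtype=torch.int),  # b
    torch.zeros((), dtype=torch.int),  # h
    torch.zeros((), dtype=torch.int),  # m
    torch.zeros((), dtype=torch.int),  # n
)
print(len(example_vals))  # 5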
@@ -845,7 +845,9 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)

@@ -889,7 +891,11 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)

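Both `decomp` hunks replace a `*args` splat with an explicit `assert len(args) == 1` and unpack, so the single positional argument is forwarded by name to a callee whose positional parameters are now typed. A self-contained sketch of the pattern (`typed_callee` and the sample data are made up):

# Illustrative only: explicit unpacking of pytree-flattened (args, kwargs) before forwarding.
from typing import Any, Dict, Optional, Tuple

import torch.utils._pytree as pytree


def typed_callee(
    op: str, clone_these: Optional[Tuple[str, ...]] = None, **kwargs: Any
) -> Dict[str, Any]:
    return {"op": op, "clone": clone_these, **kwargs}


flat_args, spec = pytree.tree_flatten((("my_op",), {"x": 1}))


def decomp(*flat: Any) -> Dict[str, Any]:
    args, kwargs = pytree.tree_unflatten(flat, spec)
    assert len(args) == 1  # the lone positional, as in the hunks above
    op = args[0]
    return typed_callee(op, None, **kwargs)


print(decomp(*flat_args))  # {'op': 'my_op', 'clone': None, 'x': 1}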
@@ -6,7 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union

 import torch
 import torch.utils._pytree as pytree
@@ -16,6 +16,9 @@ from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode


+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")

@@ -99,8 +102,8 @@ class OperatorBase:
                 return True
         return False

-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
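The `_F = TypeVar("_F", bound=Callable[..., Any])` introduced in the @@ -16,6 +16,9 @@ hunk is what lets `py_impl` be typed as `Callable[[_F], _F]`: the decorator returned by `py_impl(k)` hands back the function it received, so its original signature survives for the type checker. A standalone sketch of that decorator shape (`register` and `HANDLERS` are made-up names):

# Illustrative only: a registration decorator typed with a bound TypeVar, as in torch/_ops.py.
from typing import Any, Callable, Dict, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])

HANDLERS: Dict[str, Callable[..., Any]] = {}


def register(key: str) -> Callable[[_F], _F]:
    def inner(fn: _F) -> _F:
        HANDLERS[key] = fn
        return fn  # returned unchanged, so `fn` keeps its precise type

    return inner


@register("add")
def add(x: int, y: int) -> int:
    return x + y


print(add(1, 2), HANDLERS["add"] is add)  # 3 True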
@@ -141,7 +144,7 @@ class OperatorBase:
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #           return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
@@ -273,7 +276,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.

-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)