Add type annotations for higher order ops/flex_attention (#137065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137065
Approved by: https://github.com/drisspg, https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049
parent 3b8511dadf
commit 2854d157de
@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result
@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result
@@ -1,10 +1,9 @@
 # mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
@@ -12,7 +11,7 @@ from torch._higher_order_ops.utils import (
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
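Aside: the override above relies on TorchFunctionMode dispatch. The sketch below is a minimal, self-contained illustration of the same idea (the class name GetItemToIndex and the example tensors are invented for this note, not taken from the diff): a mode that intercepts Tensor.__getitem__ and, when every index is a tensor, routes the call through torch.ops.aten.index so scalar tensor indices do not take the view-creating path.

import torch
import torch.utils._pytree as pytree
from torch.overrides import TorchFunctionMode


class GetItemToIndex(TorchFunctionMode):  # illustrative name, not the PyTorch class
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                # Route tensor-only indexing through aten.index.
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))


with GetItemToIndex():
    buf = torch.arange(8.0)
    idx = torch.tensor(3)
    print(buf[idx])  # __getitem__ is intercepted by the mode above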
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(


 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]

     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():

-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
@@ -544,8 +553,18 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
             )
             example_grad = _from_fun(example_flat_out)

-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return out, logsumexp

     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
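Aside: the `# type: ignore[override]` above is needed because torch.autograd.Function.backward is declared with a broad `(ctx, *grad_outputs)` signature, so an override that narrows it trips mypy's override check. A minimal sketch of the same pattern (Scale is a made-up example, not from the diff):

from typing import Any, Optional, Tuple

import torch
from torch import Tensor


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Optional[Tensor], None]:  # type: ignore[override]
        # One gradient slot per forward input; the non-Tensor `factor` gets None.
        return grad_out * ctx.factor, None


x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])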
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
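Aside: the five scalar example values above correspond to the (score, b, h, q_idx, kv_idx) arguments a score_mod callable receives when the forward/backward graphs are traced. A hedged illustration (rel_bias is an invented example, not part of the diff):

import torch
from torch import Tensor


def rel_bias(score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
    # Every argument is a 0-dim tensor when traced with scalar example values.
    return score + (q_idx - kv_idx)


example_vals = (
    torch.zeros((), dtype=torch.float32, requires_grad=True),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
)
print(rel_bias(*example_vals))  # 0-dim tensor; grad flows back to the score input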
@@ -845,7 +845,9 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
@@ -889,7 +891,11 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
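Aside: both decomp helpers above depend on the pytree flatten/unflatten round trip: (args, kwargs) are flattened to leaves plus a spec, and rebuilt inside the wrapper before calling the dense op. A minimal standalone sketch (the target function and values are invented for this note):

import torch
import torch.utils._pytree as pytree


def target(mode: str, *, x: torch.Tensor, shift: int) -> torch.Tensor:
    return x + shift if mode == "add" else x - shift


args = ("add",)
kwargs = {"x": torch.arange(3), "shift": 1}
flat_args, spec = pytree.tree_flatten((args, kwargs))


def decomp(*flat: object) -> torch.Tensor:
    # Rebuild the original (args, kwargs) structure from the flat leaves.
    rebuilt_args, rebuilt_kwargs = pytree.tree_unflatten(list(flat), spec)
    assert len(rebuilt_args) == 1
    return target(rebuilt_args[0], **rebuilt_kwargs)


print(decomp(*flat_args))  # tensor([1, 2, 3])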
@@ -6,7 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union

 import torch
 import torch.utils._pytree as pytree
@@ -16,6 +16,9 @@ from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode


+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@@ -99,8 +102,8 @@ class OperatorBase:
                 return True
         return False

-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
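Aside: `_F = TypeVar("_F", bound=Callable[..., Any])` plus a `Callable[[_F], _F]` return type is the standard way to type a decorator factory so the wrapped function keeps its exact signature under mypy. A small self-contained sketch (register and _REGISTRY are invented names, not the py_impl machinery):

from typing import Any, Callable, Dict, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])

_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(key: str) -> Callable[[_F], _F]:
    def inner(fn: _F) -> _F:
        _REGISTRY[key] = fn  # side effect only; fn is returned unchanged
        return fn

    return inner


@register("add")
def add(x: int, y: int) -> int:
    return x + y


# reveal_type(add) under mypy is still "def (x: int, y: int) -> int"
print(add(1, 2))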
@@ -141,7 +144,7 @@ class OperatorBase:
     # with ctx.redispatch_to_next():
     #     out = ctx.functionalize(inner_f)(*args_unwrapped)
     #     return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
@@ -273,7 +276,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.

-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)