Add type annotations for higher order ops/flex_attention (#137065)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137065
Approved by: https://github.com/drisspg, https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049
Author: chilli, 2024-10-01 16:07:53 -07:00 (committed by PyTorch MergeBot)
Parent: 3b8511dadf
Commit: 2854d157de
4 changed files with 71 additions and 35 deletions

---- File 1 of 4 ----

@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result
@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result
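
Context for the two hunks above: passing the optional "only clone" parameter explicitly, rather than relying on `**kwargs` forwarding, gives mypy a call it can actually check. A minimal sketch of the general situation, with hypothetical names (`dense`, `fake`, `only_clone`) standing in for the real functions; whether mypy flags the original call depends on the callee's exact signature:

from typing import Any, Optional, Sequence


def dense(op: str, only_clone: Optional[Sequence[str]], **kwargs: Any) -> str:
    # Stands in for auto_functionalized_dense; note only_clone has no default.
    return op


def fake(op: str, **kwargs: Any) -> str:
    # dense(op, **kwargs) would leave mypy unable to prove only_clone is bound,
    # so the wrapper passes it explicitly:
    return dense(op, only_clone=None, **kwargs)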

---- File 2 of 4 ----

@@ -1,10 +1,9 @@
-# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
@@ -12,7 +11,7 @@ from torch._higher_order_ops.utils import (
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
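
The annotated `__torch_function__` above follows the standard `TorchFunctionMode` override shape. A self-contained sketch of the same typing pattern; the pass-through mode below is illustrative and not part of this diff:

from typing import Any, Dict, Optional, Tuple

import torch
from torch.overrides import TorchFunctionMode


class PassthroughMode(TorchFunctionMode):
    def __torch_function__(
        self,
        func: Any,
        types: Tuple[type, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        # Normalize kwargs, then defer to the intercepted function.
        return func(*args, **(kwargs or {}))


with PassthroughMode():
    torch.add(torch.ones(2), torch.ones(2))  # routed through __torch_function__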
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(

 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]
     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():

-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
@@ -544,8 +553,18 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
             )
             example_grad = _from_fun(example_flat_out)

-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return out, logsumexp

     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
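
The targeted `# type: ignore[override]` on `backward` is needed because the base class declares `backward(ctx, *grad_outputs)`; an override that names its gradient parameters and narrows the return type trips mypy's Liskov check even though it matches the runtime convention. A minimal sketch with a toy function, not anything from this diff:

from typing import Any, Optional, Tuple

import torch
from torch import Tensor


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: Tensor, factor: float) -> Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
        # One entry per forward input: a gradient for x, None for factor.
        return grad_out * ctx.factor, None


y = Scale.apply(torch.ones(3, requires_grad=True), 2.0)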
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
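
The `example_vals` rewrite above is forced by the new `index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]` parameter of `create_fw_bw_graph`: mypy can prove a fixed length for a literal tuple but not for a list built by concatenation. The same point in isolation, with a hypothetical `takes_five` helper:

from typing import Tuple

import torch
from torch import Tensor


def takes_five(vals: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> int:
    return len(vals)


vals_list = [torch.zeros(())] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
# takes_five(vals_list)  # mypy error: a List[Tensor] has no statically known length

vals_tuple = (
    torch.zeros(()),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
)
takes_five(vals_tuple)  # OK: a literal tuple has a statically known length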

---- File 3 of 4 ----

@@ -845,7 +845,9 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
@@ -889,7 +891,11 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )

         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
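
Both `decomp` rewrites above replace an `*args` splat in front of a positional parameter with an explicit single-element unpack; `assert len(args) == 1` documents the invariant and lets a type checker see exactly one positional argument. Sketched with stand-in names (`dense_like` and `clone_these` are hypothetical):

from typing import Any, Dict, Tuple


def dense_like(op: object, clone_these: Any, **kwargs: Any) -> object:
    return op


def decomp(flat_args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> object:
    # dense_like(*flat_args, None, **kwargs) is opaque to mypy: the splat has
    # unknown length, so the positional slot for clone_these cannot be checked.
    assert len(flat_args) == 1
    return dense_like(flat_args[0], None, **kwargs)


decomp(("my_op",), {})  # returns "my_op"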

---- File 4 of 4 ----

@@ -6,7 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union

 import torch
 import torch.utils._pytree as pytree
@@ -16,6 +16,9 @@ from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode

+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@@ -99,8 +102,8 @@ class OperatorBase:
                 return True
         return False

-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
@@ -141,7 +144,7 @@ class OperatorBase:
     #     with ctx.redispatch_to_next():
     #         out = ctx.functionalize(inner_f)(*args_unwrapped)
     #     return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
@@ -273,7 +276,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.
-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)
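
The `_F` TypeVar added at the top of this file is what makes the `py_impl` annotations useful: returning `Callable[[_F], _F]` tells mypy that the decorator hands the function back unchanged, so the decorated function keeps its exact signature instead of collapsing to a bare `Callable`. The pattern in isolation, with a hypothetical `py_impl_like` factory:

from typing import Any, Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])


def py_impl_like(key: Any) -> Callable[[_F], _F]:
    def inner(fn: _F) -> _F:
        # Registration side effects would happen here.
        return fn

    return inner


@py_impl_like("CPU")
def kernel(x: int, y: int) -> int:
    return x + y


kernel(1, 2)  # mypy still sees (int, int) -> int, not an untyped Callable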