Add type annotations for higher order ops/flex_attention (#137065)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137065
Approved by: https://github.com/drisspg, https://github.com/Skylion007
ghstack dependencies: #136826, #137043, #137049
Author: chilli (2024-10-01 16:07:53 -07:00), committed by PyTorch MergeBot
parent 3b8511dadf
commit 2854d157de
4 changed files with 71 additions and 35 deletions

File 1 of 4

@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result
@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result
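
(Illustrative sketch, not part of the diff.) The two fake implementations above forward to their dense counterparts and now spell out the optional clone-list keyword (_only_clone_these_tensors=None / _only_clone_these_bases=None) instead of leaving it inside **kwargs. A toy version of that forwarding pattern, with invented names:

from typing import Any, Optional, Tuple


def dense_impl(
    op_name: str,
    _only_clone_these: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> str:
    # Stand-in for auto_functionalized_dense / auto_functionalized_v2_dense.
    return f"{op_name}(clone={_only_clone_these}, kwargs={sorted(kwargs)})"


def fake_impl(op_name: str, **kwargs: Any) -> str:
    # Mirrors the fake kernels above: pass the optional argument explicitly
    # with its default value rather than relying on **kwargs to carry it.
    return dense_impl(op_name, _only_clone_these=None, **kwargs)


print(fake_impl("my.op", x=1, out=2))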

File 2 of 4

@@ -1,10 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
@@ -12,7 +11,7 @@ from torch._higher_order_ops.utils import (
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
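
(Illustrative sketch, not part of the diff.) A minimal, runnable TorchFunctionMode with the same shape as TransformGetItemToIndex above, using the annotation style this hunk introduces; the class name and the logging behaviour are invented, and func is annotated loosely as Callable so the example does not depend on OpOverload:

from typing import Callable, Dict, Optional, Tuple

import torch
from torch.overrides import TorchFunctionMode


class LogGetItem(TorchFunctionMode):
    # Hypothetical mode: it only reports tensor indexing calls, but its
    # __torch_function__ signature mirrors the annotated one added above.
    def __torch_function__(
        self,
        func: Callable[..., object],
        types: Tuple[type, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> object:
        if func is torch.Tensor.__getitem__:
            print("intercepted __getitem__ with index", args[1])
        return func(*args, **(kwargs or {}))


with LogGetItem():
    t = torch.arange(6).reshape(2, 3)
    _ = t[torch.tensor(1)]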
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(
 
 
 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]
 
     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
 
-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
@@ -544,8 +553,18 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
                 )
             example_grad = _from_fun(example_flat_out)
 
-        def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-            def fw_with_masks(*args):
+        def joint_f(
+            score: Tensor,
+            b: Tensor,
+            h: Tensor,
+            m: Tensor,
+            n: Tensor,
+            example_grad: Tensor,
+            *other_buffers: Tuple[Tensor, ...],
+        ) -> Tuple[Tensor, ...]:
+            def fw_with_masks(
+                *args: Tuple[Tensor, ...]
+            ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                 fw_out = score_mod(*args)
                 out_requires_grad = fw_out.requires_grad
                 return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return out, logsumexp
 
     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
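
(Illustrative sketch, not part of the diff; rel_bias is an invented score_mod.) The five scalar example_vals built above line up with the (score, b, h, m, n) index_values that create_fw_bw_graph is now typed to accept, which in turn mirror the scalar arguments a user-supplied score_mod receives:

import torch
from torch import Tensor


def rel_bias(
    score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor
) -> Tensor:
    # Toy score_mod: one float scalar plus four int scalars in, one scalar out.
    return score + (q_idx - kv_idx)


# Built the same way as example_vals in the hunk above.
example_vals = (
    torch.zeros((), dtype=torch.float32, requires_grad=True),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
    torch.zeros((), dtype=torch.int),
)
out = rel_bias(*example_vals)
assert out.shape == () and out.requires_grad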

File 3 of 4

@@ -845,7 +845,9 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
@@ -889,7 +891,11 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
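
(Illustrative sketch, not part of the diff; the values are placeholders.) The decomp helpers above rely on a fixed flatten/unflatten convention: exactly one positional element (the mutable op) plus arbitrary kwargs are flattened for tracing, so after tree_unflatten it is safe to assert len(args) == 1 and bind args[0] to a name, as the new code does:

import torch.utils._pytree as pytree

args = ("my.mutable_op",)  # stand-in for the op object
kwargs = {"x": 1, "out": 2}  # stand-in for the op's keyword arguments

flat_args, spec = pytree.tree_flatten((args, kwargs))

# flat_args is what gets traced; the structure is restored afterwards.
new_args, new_kwargs = pytree.tree_unflatten(flat_args, spec)
assert len(new_args) == 1
mutable_op = new_args[0]
assert mutable_op == "my.mutable_op" and new_kwargs == {"x": 1, "out": 2}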

File 4 of 4

@@ -6,7 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -16,6 +16,9 @@ from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode
 
+_F = TypeVar("_F", bound=Callable[..., Any])
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@@ -99,8 +102,8 @@ class OperatorBase:
                 return True
         return False
 
-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
@@ -141,7 +144,7 @@ class OperatorBase:
     # with ctx.redispatch_to_next():
     #     out = ctx.functionalize(inner_f)(*args_unwrapped)
     #     return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
@@ -273,7 +276,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.
-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)
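
(Illustrative sketch, not part of the diff; ToyOperator and the "cpu" key are invented, and the real py_impl also special-cases TorchDispatchMode, Tensor subclasses, and functorch transforms.) Binding a TypeVar to Callable and returning Callable[[_F], _F] tells the type checker that @op.py_impl(...) hands back the decorated function unchanged, so its signature stays visible at call sites:

from typing import Any, Callable, Dict, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])


class ToyOperator:
    def __init__(self) -> None:
        self.py_kernels: Dict[Any, Callable[..., Any]] = {}

    def py_impl(self, k: Any) -> Callable[[_F], _F]:
        def inner(fn: _F) -> _F:
            # Register and return the function untouched, like the real
            # inner(); the _F annotations preserve fn's own type.
            self.py_kernels[k] = fn
            return fn

        return inner


op = ToyOperator()


@op.py_impl("cpu")
def add_one(x: int) -> int:
    return x + 1


assert op.py_kernels["cpu"](41) == 42
assert add_one(1) == 2  # mypy still sees add_one as (x: int) -> int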