diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 232981f1f01..ac06b9c8229 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -586,7 +586,9 @@ def auto_functionalized_fake(
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_dense(
+            _mutable_op, _only_clone_these_tensors=None, **kwargs
+        )
         return result
 
 
@@ -681,7 +683,9 @@ def auto_functionalized_v2_fake(
     **kwargs: Dict[str, Any],
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
     with mode:
-        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
+        result = auto_functionalized_v2_dense(
+            _mutable_op, _only_clone_these_bases=None, **kwargs
+        )
         return result
 
 
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index dbd2020b48f..94e39a96ca7 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -1,10 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import math
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._C import DispatchKey
 from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
@@ -12,7 +11,7 @@ from torch._higher_order_ops.utils import (
     reenter_make_fx,
     UnsupportedAliasMutationException,
 )
-from torch._ops import HigherOrderOperator
+from torch._ops import HigherOrderOperator, OpOverload
 from torch._subclasses import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
     make_fx,
@@ -77,7 +76,13 @@
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args=(), kwargs=None):
+    def __torch_function__(
+        self,
+        func: OpOverload,
+        types: Tuple[torch._C._TensorMeta, ...],
+        args: Tuple[object, ...] = (),
+        kwargs: Optional[Dict[str, object]] = None,
+    ) -> object:
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):
@@ -485,7 +490,11 @@ def flex_attention_fake_tensor_mode(
 
 
 # ---------------------------- Autograd Implementation ----------------------------
-def create_fw_bw_graph(score_mod, index_values, other_buffers):
+def create_fw_bw_graph(
+    score_mod: Callable,
+    index_values: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
+    other_buffers: Tuple[Tensor, ...],
+) -> Tuple[Callable, Callable]:
     # See Note:[HOP create fw_bw graph]
 
     # All of these imports need to be here in order to avoid circular dependencies
@@ -508,7 +517,7 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
 
-            def _from_fun(t):
+            def _from_fun(t: Tensor) -> Tensor:
                 return torch.empty_strided(
                     t.size(),
                     t.stride(),
@@ -544,8 +553,18 @@
             )
             example_grad = _from_fun(example_flat_out)
 
-            def joint_f(score, b, h, m, n, example_grad, *other_buffers):
-                def fw_with_masks(*args):
+            def joint_f(
+                score: Tensor,
+                b: Tensor,
+                h: Tensor,
+                m: Tensor,
+                n: Tensor,
+                example_grad: Tensor,
+                *other_buffers: Tuple[Tensor, ...],
+            ) -> Tuple[Tensor, ...]:
+                def fw_with_masks(
+                    *args: Tuple[Tensor, ...]
+                ) -> Tuple[Tuple[Tensor], Tuple[bool]]:
                     fw_out = score_mod(*args)
                     out_requires_grad = fw_out.requires_grad
                     return ((fw_out,), (out_requires_grad,))
@@ -566,17 +585,17 @@
 class FlexAttentionAutogradOp(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        query,
-        key,
-        value,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
+        ctx: Any,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        fw_graph: Callable,
+        joint_graph: Callable,
+        block_mask: Tuple[Any, ...],
+        scale: float,
+        kernel_options: Dict[str, Any],
+        score_mod_other_buffers: Tuple[Any, ...],
+        mask_mod_other_buffers: Tuple[Any, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         any_buffer_requires_grad = any(
             buffer.requires_grad
@@ -620,7 +639,7 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
         return out, logsumexp
 
     @staticmethod
-    def backward(ctx, grad_out, grad_logsumexp):
+    def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Optional[Tensor], ...]:  # type: ignore[override]
         fw_args = ctx.saved_tensors
         (
             query,
@@ -693,15 +712,19 @@ def flex_attention_autograd(
     block_mask: Tuple,
     scale: float,
     kernel_options: Dict[str, Any],
-    score_mod_other_buffers: Tuple = (),
-    mask_mod_other_buffers: Tuple = (),
+    score_mod_other_buffers: Tuple[Tensor, ...] = (),
+    mask_mod_other_buffers: Tuple[Tensor, ...] = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     with TransformGetItemToIndex():
         input_requires_grad = any(t.requires_grad for t in (query, key, value))
         if torch.is_grad_enabled() and input_requires_grad:
-            example_vals = [
-                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
-            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
+            example_vals = (
+                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+                torch.zeros((), dtype=torch.int),
+            )
             fw_graph, bw_graph = create_fw_bw_graph(
                 score_mod, example_vals, score_mod_other_buffers
             )
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index ae82ca14b52..06d05983988 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -845,7 +845,9 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_dense(*args, only_clone_these_tensors, **kwargs)
+            assert len(args) == 1
+            mode = args[0]
+            return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs)
 
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
 
@@ -889,7 +891,11 @@ def decompose_auto_functionalized(graph):
         # tracing a function with kwargs.
         def decomp(*flat_args):
             args, kwargs = pytree.tree_unflatten(flat_args, spec)
-            return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs)
+            assert len(args) == 1
+            mutable_op = args[0]
+            return auto_functionalized_v2_dense(
+                mutable_op, only_clone_these_bases, **kwargs
+            )
 
         match.replace_by_example(decomp, flat_args, run_functional_passes=False)
 
diff --git a/torch/_ops.py b/torch/_ops.py
index a4171013bec..a823d44142f 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -6,7 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, List, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, TypeVar, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -16,6 +16,9 @@ from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode
 
+_F = TypeVar("_F", bound=Callable[..., Any])
+
+
 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
 
 
@@ -99,8 +102,8 @@ class OperatorBase:
                 return True
         return False
 
-    def py_impl(self, k):
-        def inner(fn):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
+        def inner(fn: _F) -> _F:
             if inspect.isclass(k) and (
                 issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor)
             ):
@@ -141,7 +144,7 @@ class OperatorBase:
     #       with ctx.redispatch_to_next():
     #           out = ctx.functionalize(inner_f)(*args_unwrapped)
     #           return ctx.wrap_tensors(out)
-    def py_functionalize_impl(self, fn):
+    def py_functionalize_impl(self, fn: _F) -> _F:
         from torch._subclasses.functional_tensor import (
             CppFunctionalizeAPI as _CppFunctionalizeAPI,
             FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
@@ -273,7 +276,7 @@ class HigherOrderOperator(OperatorBase, abc.ABC):
     # it to next key. This is only safe to do when PreDispatch key stack has no
     # active modes.
-    def py_impl(self, k):
+    def py_impl(self, k: Any) -> Callable[[_F], _F]:
         if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k):
             self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
         return super().py_impl(k)
 
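Note on the typing pattern in the torch/_ops.py hunks: binding a TypeVar to Callable (_F = TypeVar("_F", bound=Callable[..., Any])) and returning Callable[[_F], _F] from py_impl tells static type checkers that the decorator hands back the decorated function with its original signature intact. Below is a minimal standalone sketch of that pattern, assuming a hypothetical register() decorator and _REGISTRY dict (illustrative only, not PyTorch code):

    from typing import Any, Callable, Dict, TypeVar

    _F = TypeVar("_F", bound=Callable[..., Any])

    _REGISTRY: Dict[str, Callable[..., Any]] = {}


    def register(key: str) -> Callable[[_F], _F]:
        # Returning the function under the same TypeVar preserves its exact
        # signature for checkers such as mypy, instead of erasing it to an
        # untyped callable.
        def inner(fn: _F) -> _F:
            _REGISTRY[key] = fn  # hypothetical registry, stands in for py_kernels
            return fn

        return inner


    @register("add")
    def add(x: int, y: int) -> int:
        return x + y


    # add keeps its (int, int) -> int signature after decoration, so the call
    # below still type-checks, and the registry holds the original function.
    assert add(1, 2) == 3
    assert _REGISTRY["add"] is add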