Type exposed_in decorator (#141894)

Test Plan:
- lintrunner
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141894
Approved by: https://github.com/albanD
Author: rzou (2024-12-02 13:22:01 -08:00), committed by PyTorch MergeBot
parent 7a806a839d
commit ac600fdce6
4 changed files with 12 additions and 7 deletions

View File

@@ -26,7 +26,7 @@ def custom_op(
     mutates_args: Union[str, Iterable[str]],
     device_types: device_types_t = None,
     schema: Optional[str] = None,
-) -> Callable:
+) -> Any:
     """Wraps a function into custom operator.

     Reasons why you may want to create a custom op include:

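Note: custom_op here is the decorator behind torch.library.custom_op. As a
minimal usage sketch (the mylib::numpy_sin operator and its NumPy body are
illustrative, not part of this diff):

import numpy as np
import torch

# Hypothetical custom operator whose implementation leaves PyTorch and
# computes with NumPy; mutates_args=() declares it mutates no inputs.
@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.cpu().numpy()
    return torch.from_numpy(np.sin(x_np)).to(x.device)

x = torch.randn(3)
assert torch.allclose(numpy_sin(x), torch.sin(x))

When used this way the decorator returns a CustomOpDef object (which is
callable) rather than a plain function, which is presumably why the return
annotation is loosened from Callable to Any.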
View File

@@ -1,6 +1,6 @@
 import contextlib
 import threading
-from typing import Callable, Generator, Iterable, Optional, Union
+from typing import Any, Callable, Generator, Iterable, Optional, Union

 from .custom_ops import custom_op
 from .infer_schema import infer_schema
@@ -92,7 +92,7 @@ def triton_op(
     """

-    def dec(fn: Callable) -> Callable:
+    def dec(fn: Callable) -> Any:
         def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
             # Optimization: we're passing regular Tensors into the triton kernel, so
             # no need to go through HOP dispatch

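Note: triton_op (exposed as torch.library.triton_op) wraps a function that
launches triton kernels so the launches stay visible to torch.compile. A rough
usage sketch, assuming a CUDA build with triton installed; the mylib::add op
and add_kernel are illustrative, and the kernel wrapper is named wrap_triton
in recent releases (it was exposed as capture_triton around the time of this
commit):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@triton_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    # Launch through the wrapper instead of calling add_kernel directly,
    # so the kernel is not an opaque call from the compiler's perspective.
    wrap_triton(add_kernel)[grid](x, y, out, n, BLOCK_SIZE=16)
    return out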
View File

@@ -115,7 +115,7 @@ def _vmap_for_bhqkv(
     ]
     for dims in dimensions:
-        fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims)
+        fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims)  # type: ignore[arg-type]
     return fn

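Note: _vmap_for_bhqkv builds a nested vmap by repeatedly wrapping fn, one
level per dimension tuple. A self-contained sketch of the same stacking
pattern (the score function and in_dims tuples are made up):

import torch

def score(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a * b  # scalar score for two scalar inputs

fn = score
# Each pass maps over one argument's leading dim; None leaves the other
# argument unbatched at that level, like the dimensions loop above.
for in_dims in [(0, None), (None, 0)]:
    fn = torch.vmap(fn, in_dims=in_dims, out_dims=0)

a = torch.randn(3)
b = torch.randn(4)
print(fn(a, b).shape)  # torch.Size([4, 3]): one score per (b_j, a_i) pair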
View File

@@ -1,4 +1,9 @@
-# mypy: allow-untyped-defs
+from typing import Callable, TypeVar
+
+
+F = TypeVar("F")
+
+
 # Allows one to expose an API in a private submodule publicly as per the definition
 # in PyTorch's public api policy.
 #
@@ -7,8 +12,8 @@
 # may not be very robust because it's not clear what __module__ is used for.
 # However, both numpy and jax overwrite the __module__ attribute of their APIs
 # without problem, so it seems fine.
-def exposed_in(module):
-    def wrapper(fn):
+def exposed_in(module: str) -> Callable[[F], F]:
+    def wrapper(fn: F) -> F:
         fn.__module__ = module
         return fn
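Taken together, exposed_in is now typed so that mypy treats the decorator as
identity-typed on the wrapped function. A runnable sketch of the decorator as
it stands after this change, with a made-up public module and API for
illustration:

from typing import Callable, TypeVar

F = TypeVar("F")

def exposed_in(module: str) -> Callable[[F], F]:
    def wrapper(fn: F) -> F:
        # Setting __module__ works at runtime for functions; mypy cannot
        # verify the attribute on a bare TypeVar, hence the ignore here.
        fn.__module__ = module  # type: ignore[attr-defined]
        return fn

    return wrapper

@exposed_in("torch.library")
def my_api() -> None:
    """Lives in a private submodule but reports torch.library as its home."""

assert my_api.__module__ == "torch.library"

Annotating the return as Callable[[F], F] means a decorated function keeps its
original signature under type checking instead of degrading to Any.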