mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Type exposed_in decorator (#141894)
Test Plan: - lintrunner Pull Request resolved: https://github.com/pytorch/pytorch/pull/141894 Approved by: https://github.com/albanD
This commit is contained in:
parent
7a806a839d
commit
ac600fdce6
|
|
@ -26,7 +26,7 @@ def custom_op(
|
|||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Callable:
|
||||
) -> Any:
|
||||
"""Wraps a function into custom operator.
|
||||
|
||||
Reasons why you may want to create a custom op include:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import contextlib
|
||||
import threading
|
||||
from typing import Callable, Generator, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Generator, Iterable, Optional, Union
|
||||
|
||||
from .custom_ops import custom_op
|
||||
from .infer_schema import infer_schema
|
||||
|
|
@ -92,7 +92,7 @@ def triton_op(
|
|||
|
||||
"""
|
||||
|
||||
def dec(fn: Callable) -> Callable:
|
||||
def dec(fn: Callable) -> Any:
|
||||
def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
# Optimization: we're passing regular Tensors into the triton kernel, so
|
||||
# no need to go through HOP dispatch
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def _vmap_for_bhqkv(
|
|||
]
|
||||
|
||||
for dims in dimensions:
|
||||
fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims)
|
||||
fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=out_dims) # type: ignore[arg-type]
|
||||
return fn
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
|
||||
F = TypeVar("F")
|
||||
|
||||
|
||||
# Allows one to expose an API in a private submodule publicly as per the definition
|
||||
# in PyTorch's public api policy.
|
||||
#
|
||||
|
|
@ -7,8 +12,8 @@
|
|||
# may not be very robust because it's not clear what __module__ is used for.
|
||||
# However, both numpy and jax overwrite the __module__ attribute of their APIs
|
||||
# without problem, so it seems fine.
|
||||
def exposed_in(module):
|
||||
def wrapper(fn):
|
||||
def exposed_in(module: str) -> Callable[[F], F]:
|
||||
def wrapper(fn: F) -> F:
|
||||
fn.__module__ = module
|
||||
return fn
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user