mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131578 Approved by: https://github.com/oulgen, https://github.com/zou3519 ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577
22 lines
726 B
Python
22 lines
726 B
Python
# mypy: allow-untyped-defs
|
|
|
|
from typing import Callable, TypeVar
|
|
|
|
_T = TypeVar("_T")


# Allows one to expose an API in a private submodule publicly as per the definition
# in PyTorch's public API policy.
#
# It is a temporary solution while we figure out if it should be the long-term solution
# or if we should amend PyTorch's public API policy. The concern is that this approach
# may not be very robust because it's not clear what __module__ is used for.
# However, both numpy and jax overwrite the __module__ attribute of their APIs
# without problem, so it seems fine.
def exposed_in(module: str) -> Callable[[_T], _T]:
    """Return a decorator that makes its target report ``module`` as its home.

    The decorator assigns ``module`` to the decorated object's ``__module__``
    attribute and returns that very same object. It does not wrap it, so
    identity, signature and other attributes are left untouched.

    Args:
        module: Dotted name of the (public) module the decorated object
            should appear to belong to.

    Returns:
        A decorator ``(obj) -> obj`` that mutates ``obj.__module__`` in place.

    Raises:
        The returned decorator propagates the ``AttributeError``/``TypeError``
        Python raises for objects whose ``__module__`` is read-only
        (e.g. builtin functions or builtin types).
    """

    def wrapper(fn: _T) -> _T:
        # Mutate in place rather than creating a wrapper so callers receive
        # the original object back.
        fn.__module__ = module
        return fn

    return wrapper
|