diff --git a/test/test_utils.py b/test/test_utils.py index 22910af8913..080afe76159 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.utils.cpp_extension import torch.utils.data from torch._utils import try_import +from torch._utils_internal import deprecated from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -61,6 +62,9 @@ HAS_CUDA = torch.cuda.is_available() from torch.testing._internal.common_utils import run_tests, TestCase +# mypy: disable-error-code="name-defined" + + class RandomDatasetMock(torch.utils.data.Dataset): def __getitem__(self, index): return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) @@ -1197,5 +1201,20 @@ class TestTryImport(TestCase): self.assertIsNone(missing_module) +@deprecated() +def _deprecated_api(x, y=15): + return x + y + + +class TestDeprecate(TestCase): + def test_deprecated(self): + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, 2) # noqa: F821 + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, y=2) # noqa: F821 + _deprecated_api(1, 2) + _deprecated_api(1, y=2) + + if __name__ == "__main__": run_tests() diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index b5da1941450..fd8b8f08f8b 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -4,6 +4,7 @@ import logging import os import sys import tempfile +import typing_extensions from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -282,3 +283,54 @@ def record_chromium_event_internal( def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12(): return True + + +def deprecated(): + """ + When we deprecate a function that might still be in use, we make it internal + by adding a leading underscore. This decorator is used with a private function, + and creates a public alias without the leading underscore, but has a deprecation + warning. This tells users "THIS FUNCTION IS DEPRECATED, please use something else" + without breaking them, however, if they still really really want to use the + deprecated function without the warning, they can do so by using the internal + function name. + """ + + def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + # Validate naming convention – single leading underscore, not dunder + if not (func.__name__.startswith("_")): + raise ValueError( + "@deprecate must decorate a function whose name " + "starts with a single leading underscore (e.g. '_foo') as the api should be considered internal for deprecation." + ) + + public_name = func.__name__[1:] # drop exactly one leading underscore + module = sys.modules[func.__module__] + + # Don't clobber an existing symbol accidentally. + if hasattr(module, public_name): + raise RuntimeError( + f"Cannot create alias '{public_name}' -> symbol already exists in {module.__name__}. \ + Please rename it or consult a pytorch developer on what to do" + ) + + warning_msg = f"{func.__name__[1:]} is DEPRECATED, please consider using an alternative API(s). " + + # public deprecated alias + alias = typing_extensions.deprecated( + warning_msg, category=UserWarning, stacklevel=1 + )(func) + + alias.__name__ = public_name + + # Adjust qualname if nested inside a class or another function + if "." in func.__qualname__: + alias.__qualname__ = func.__qualname__.rsplit(".", 1)[0] + "." + public_name + else: + alias.__qualname__ = public_name + + setattr(module, public_name, alias) + + return func + + return decorator