mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[BE][CI][Easy] Run lintrunner on generated .pyi stub files (#150732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150732 Approved by: https://github.com/malfet, https://github.com/cyyever, https://github.com/aorenste
This commit is contained in:
parent 0a7eef140b
commit 7ae204c3b6

.github/scripts/lintrunner.sh (vendored), 6 additions
@@ -31,6 +31,9 @@ python3 -m tools.pyi.gen_pyi \
     --deprecated-functions-path "tools/autograd/deprecated.yaml"
 python3 torch/utils/data/datapipes/gen_pyi.py
 
+# Also check generated pyi files
+find torch -name '*.pyi' -exec git add --force -- "{}" +
+
 RC=0
 # Run lintrunner on all files
 if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
@@ -41,6 +44,9 @@ if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS}
     RC=1
 fi
 
+# Unstage temporally added pyi files
+find torch -name '*.pyi' -exec git restore --staged -- "{}" +
+
 # Use jq to massage the JSON lint output into GitHub Actions workflow commands.
 jq --raw-output \
     '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \

@@ -212,6 +212,11 @@ select = [
 "__init__.py" = [
     "F401",
 ]
+"*.pyi" = [
+    "PYI011",  # typed-argument-default-in-stub
+    "PYI021",  # docstring-in-stub
+    "PYI053",  # string-or-bytes-too-long
+]
 "functorch/notebooks/**" = [
     "F401",
 ]
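
Note on the three ignored rules: the generated stubs intentionally trip them, so they are suppressed wholesale for *.pyi files rather than fixed case by case. A rough illustration of what each rule flags, using a hypothetical stub fragment (not from this commit; rule meanings paraphrased from the ruff documentation):

# example.pyi (hypothetical stub, for illustration only)

# PYI011: a typed argument whose default is more than a simple literal;
# stubs conventionally elide the value and write `= ...` instead.
def configure(options: dict[str, int] = {}) -> None: ...

# PYI021: docstrings belong in the .py implementation, not in the stub.
def helper() -> None:
    """Do the thing."""

# PYI053: string constants longer than 50 characters should be shortened
# or elided to `...` in a stub.
DEFAULT_BANNER: str = "a fairly long string literal that easily exceeds the fifty character limit"
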

@@ -113,9 +113,9 @@ class MpsMemoryLeakCheck:
         self.caching_allocator_before = torch.mps.current_allocated_memory()
         self.driver_before = torch.mps.driver_allocated_memory()
 
-    def __exit__(self, exec_type, exec_value, traceback):
+    def __exit__(self, exc_type, exc_value, traceback):
         # Don't check for leaks if an exception was thrown
-        if exec_type is not None:
+        if exc_type is not None:
             return
         # Compares caching allocator before/after statistics
         # An increase in allocated memory is a discrepancy indicating a possible memory leak
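
Note on the rename: exec_type/exec_value are nonstandard spellings of the conventional exc_type/exc_value; the interpreter passes these arguments positionally, so the rename cannot change behaviour, it only makes the signature read as expected. A minimal sketch of the fully annotated convention (illustrative only, not code from this commit):

from __future__ import annotations  # lets the `| None` annotations run on Python 3.9 too

from types import TracebackType


class Guard:
    def __enter__(self) -> "Guard":
        return self

    # The three positional arguments every __exit__ receives: the exception
    # class, the exception instance, and the traceback (all None on a clean exit).
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if exc_type is not None:
            return  # an exception is propagating; skip success-only checks
        # success-path checks would go here
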

@@ -1,20 +1,11 @@
 # ${generated_comment}
 # mypy: disable-error-code="type-arg"
 # mypy: allow-untyped-defs
+# ruff: noqa: F401,PYI054
 
-import builtins
+from collections.abc import Sequence
 from types import EllipsisType
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Iterator,
-    Literal,
-    NamedTuple,
-    overload,
-    Sequence,
-    TypeVar,
-)
+from typing import Any, Callable, Literal, overload, TypeVar
 
 import torch
 from torch import (
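
Note on the import change: typing.Sequence, typing.Iterator and friends are deprecated aliases since Python 3.9; the ABCs themselves live in collections.abc, so the stubs switch the import and collapse it onto one line. A short sketch (hypothetical module, illustrative only):

# Deprecated spelling: importing container ABCs from typing.
# from typing import Iterator, Sequence

# Preferred spelling: the ABCs come from collections.abc; typing keeps only
# constructs that have no runtime counterpart (TypeVar, overload, ...).
from collections.abc import Iterator, Sequence


def evens(xs: Sequence[int]) -> Iterator[int]:
    # Behaviour is identical; only the import location changes.
    return (x for x in xs if x % 2 == 0)
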

@@ -1,8 +1,9 @@
 # ${generated_comment}
 # mypy: disable-error-code="type-arg"
 # mypy: allow-untyped-defs
+# ruff: noqa: F401
 
-import builtins
+from collections.abc import Iterable, Iterator, Sequence
 from enum import Enum, IntEnum
 from pathlib import Path
 from types import EllipsisType
@@ -10,17 +11,13 @@ from typing import (
     Any,
     AnyStr,
     Callable,
-    ContextManager,
     Generic,
     IO,
-    Iterable,
-    Iterator,
     Literal,
     NamedTuple,
     overload,
     Protocol,
     runtime_checkable,
-    Sequence,
     SupportsIndex,
     TypeVar,
 )
@@ -71,15 +68,15 @@ from torch.utils._python_dispatch import TorchDispatchMode
 
 # This module is defined in torch/csrc/Module.cpp
 
-K = TypeVar("K")
-T = TypeVar("T")
-S = TypeVar("S", bound=torch.Tensor)
-P = ParamSpec("P")
-ReturnVal = TypeVar("ReturnVal", covariant=True)  # return value (always covariant)
-_T_co = TypeVar("_T_co", covariant=True)
+K = TypeVar("K")  # noqa: PYI001
+T = TypeVar("T")  # noqa: PYI001
+S = TypeVar("S", bound=torch.Tensor)  # noqa: PYI001
+P = ParamSpec("P")  # noqa: PYI001
+R = TypeVar("R", covariant=True)  # return value (always covariant)  # noqa: PYI001
+T_co = TypeVar("T_co", covariant=True)  # noqa: PYI001
 
 @runtime_checkable
-class _NestedSequence(Protocol[_T_co]):
+class _NestedSequence(Protocol[T_co]):
     """A protocol for representing nested sequences.
 
     References::
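
Note on the noqa markers: ruff's PYI001 asks that module-level TypeVars in a stub be prefixed with an underscore so they read as private to the stub; this template keeps its public, unprefixed names and silences the rule on each line instead. A hedged illustration of the rule itself (hypothetical stub, paraphrasing the ruff documentation):

from typing import TypeVar

# What PYI001 asks for in a stub: the leading underscore marks the TypeVar
# as private, so it is not mistaken for part of the module's API.
_T = TypeVar("_T")

# What PYI001 flags: a public, unprefixed TypeVar at module scope.
T = TypeVar("T")  # noqa: PYI001  (silenced, as the template above does)
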
@@ -88,10 +85,10 @@ class _NestedSequence(Protocol[_T_co]):
     """
 
     def __len__(self, /) -> _int: ...
-    def __getitem__(self, index: _int, /) -> _T_co | _NestedSequence[_T_co]: ...
+    def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ...
     def __contains__(self, x: object, /) -> _bool: ...
-    def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
-    def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
+    def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
+    def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
     def count(self, value: Any, /) -> _int: ...
     def index(self, value: Any, /) -> _int: ...
 
@@ -146,7 +143,7 @@ class Stream:
     def record_event(self, event: Event | None = None) -> Event: ...
     def __hash__(self) -> _int: ...
     def __eq__(self, other: object) -> _bool: ...
-    def __enter__(self) -> Stream: ...
+    def __enter__(self) -> Self: ...
     def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
 
 # Defined in torch/csrc/Event.cpp
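
Note on the Self return type: annotating __enter__ as returning Self (rather than the concrete class) lets subclasses keep their own type inside a with block. A small sketch of the difference, assuming typing_extensions is available (hypothetical classes, illustrative only):

from typing_extensions import Self  # `from typing import Self` on Python 3.11+


class BaseStream:
    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass


class ExternalStream(BaseStream):
    def query(self) -> bool:
        return True


with ExternalStream() as s:
    # With `-> Self`, s is inferred as ExternalStream, so .query() type-checks;
    # with `-> BaseStream`, s would be widened to the base class instead.
    s.query()
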
@@ -321,14 +318,14 @@ def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
 def unify_type_list(types: list[JitType]) -> JitType: ...
 def _freeze_module(
     module: ScriptModule,
-    preserved_attrs: list[str] = [],
+    preserved_attrs: list[str] = ...,
     freeze_interfaces: _bool = True,
     preserveParameters: _bool = True,
 ) -> ScriptModule: ...
 def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
 def _jit_pass_optimize_for_inference(
     module: torch.jit.ScriptModule,
-    other_methods: list[str] = [],
+    other_methods: list[str] = ...,
 ) -> None: ...
 def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
 def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
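
Note on the `= ...` defaults: in a stub the default expression is never executed, so a mutable literal such as [] carries no information and only trips the mutable-default lints; the convention is to elide it with `...` and keep the real default in the implementation. A hedged sketch (hypothetical function, not part of torch):

from __future__ import annotations


# Implementation module: the real default lives here, where it actually runs.
def freeze(module: object, preserved_attrs: list[str] | None = None) -> object:
    preserved_attrs = [] if preserved_attrs is None else preserved_attrs
    return module


# Matching stub (.pyi): both the body and the concrete default are elided.
# def freeze(module: object, preserved_attrs: list[str] = ...) -> object: ...
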
@@ -759,7 +756,7 @@ class AliasDb: ...
 
 class _InsertPoint:
     def __enter__(self) -> None: ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 # Defined in torch/csrc/jit/ir/ir.h
 class Use:
@@ -1078,8 +1075,8 @@ class LiteScriptModule:
     def run_method(self, method_name: str, *input): ...
 
 # NOTE: switch to collections.abc.Callable in python 3.9
-class ScriptFunction(Generic[P, ReturnVal]):
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
+class ScriptFunction(Generic[P, R]):
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
     def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ...
     def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ...
     @property
@@ -1092,9 +1089,9 @@ class ScriptFunction(Generic[P, ReturnVal]):
     def qualified_name(self) -> str: ...
 
 # NOTE: switch to collections.abc.Callable in python 3.9
-class ScriptMethod(Generic[P, ReturnVal]):
+class ScriptMethod(Generic[P, R]):
     graph: Graph
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
     @property
     def owner(self) -> ScriptModule: ...
     @property
@@ -1375,7 +1372,7 @@ class _LinalgBackend:
 # members. There is a chance this is due to a recent change in the semantics
 # of enum membership. If so, use `member = value` to mark an enum member,
 # instead of `member: type`
 class BatchNormBackend(Enum): ...  # type: ignore[misc]
 
 def _get_blas_preferred_backend() -> _BlasBackend: ...
 def _set_blas_preferred_backend(arg: _BlasBackend): ...
@@ -1400,7 +1397,7 @@ class _ROCmFABackend:
 # There is a chance this is due to a recent change in the semantics of enum
 # membership. If so, use `member = value` to mark an enum member, instead of
 # `member: type`
 class ConvBackend(Enum): ...  # type: ignore[misc]
 
 class Tag(Enum):
     ${tag_attributes}
@@ -1481,9 +1478,7 @@ def _get_function_stack_at(idx: _int) -> Any: ...
 def _len_torch_function_stack() -> _int: ...
 def _set_torch_dispatch_mode(cls: Any) -> None: ...
 def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
-def _pop_torch_dispatch_stack(
-    mode_key: _TorchDispatchModeKey | None = None,
-) -> Any: ...
+def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ...
 def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ...
 def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ...
 def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
@@ -1494,42 +1489,42 @@ def _activate_gpu_trace() -> None: ...
 class _DisableTorchDispatch:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _EnableTorchFunction:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _EnablePythonDispatcher:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _DisablePythonDispatcher:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _EnablePreDispatch:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _DisableFuncTorch:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _DisableAutocast:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _InferenceMode:
     def __init__(self, enabled: _bool) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 def _set_autograd_fallback_mode(mode: str) -> None: ...
 def _get_autograd_fallback_mode() -> str: ...
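
Note on the `*exc_info: object` signatures: for context managers that never inspect the exception, collapsing the three __exit__ parameters into one star-arg typed as object is a compact stub signature; object is deliberately wide because the values are ignored. A small illustration (hypothetical guard class):

class _Guard:
    def __enter__(self) -> "_Guard":
        return self

    # Equivalent to accepting (exc_type, exc_value, traceback) individually
    # and ignoring them.
    def __exit__(self, *exc_info: object) -> None:
        del exc_info  # nothing to inspect on exit


with _Guard():
    pass
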
@@ -1783,32 +1778,32 @@ def _commit_update(a: Tensor) -> None: ...
 class _ExcludeDispatchKeyGuard:
     def __init__(self, keyset: DispatchKeySet) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _IncludeDispatchKeyGuard:
     def __init__(self, k: DispatchKey) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _ForceDispatchKeyGuard:
     def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _PreserveDispatchKeyGuard:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _AutoDispatchBelowAutograd:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 class _AutoDispatchBelowADInplaceOrView:
     def __init__(self) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
 def _dispatch_get_registrations_for_dispatch_key(
@@ -1827,18 +1822,16 @@ class _TorchDispatchModeKey(Enum):
 class _SetExcludeDispatchKeyGuard:
     def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
     def __enter__(self): ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...
 
 # Defined in torch/csrc/utils/schema_info.h
 
 class _SchemaInfo:
     def __init__(self, schema: _int) -> None: ...
-
     @overload
     def is_mutable(self) -> _bool: ...
     @overload
     def is_mutable(self, name: str) -> _bool: ...
-
     def has_argument(self, name: str) -> _bool: ...
 
 # Defined in torch/csrc/utils/init.cpp
@@ -2431,7 +2424,7 @@ def _create_graph_by_tracing(
     strict: Any,
     force_outplace: Any,
     self: Any = None,
-    argument_names: list[str] = [],
+    argument_names: list[str] = ...,
 ) -> tuple[Graph, Stack]: ...
 def _tracer_warn_use_python(): ...
 def _get_tracing_state() -> TracingState: ...
@@ -2458,8 +2451,6 @@ class InferredType:
     def success(self) -> _bool: ...
     def reason(self) -> str: ...
 
-R = TypeVar("R", bound=JitType)
-
 class Type(JitType):
     def str(self) -> _str: ...
     def containedTypes(self) -> list[JitType]: ...
@@ -2558,16 +2549,18 @@ class UnionType(JitType):
 
 class ClassType(JitType):
     def __init__(self, qualified_name: str) -> None: ...
-    def qualified_name(self) ->str: ...
+    def qualified_name(self) -> str: ...
 
 class InterfaceType(JitType):
     def __init__(self, qualified_name: str) -> None: ...
     def getMethod(self, name: str) -> FunctionSchema | None: ...
     def getMethodNames(self) -> list[str]: ...
 
-class OptionalType(JitType, Generic[R]):
-    def __init__(self, a: JitType) -> None: ...
-    def getElementType(self) -> JitType: ...
+JitTypeT = TypeVar("JitTypeT", bound=JitType)  # noqa: PYI001
+
+class OptionalType(JitType, Generic[JitTypeT]):
+    def __init__(self, a: JitTypeT) -> None: ...
+    def getElementType(self) -> JitTypeT: ...
     @staticmethod
     def ofTensor() -> OptionalType: ...
 
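
Note on the OptionalType rewrite: binding a dedicated TypeVar to JitType lets the element type flow from the constructor argument through to getElementType, instead of always widening to JitType. A sketch of the pattern with stand-in classes (hypothetical, illustrative only):

from typing import Generic, TypeVar


class JitType: ...


class IntType(JitType): ...


JitTypeT = TypeVar("JitTypeT", bound=JitType)


class OptionalType(JitType, Generic[JitTypeT]):
    def __init__(self, a: JitTypeT) -> None:
        self._elem = a

    def getElementType(self) -> JitTypeT:
        # The return type tracks whatever was passed to the constructor:
        # OptionalType(IntType()).getElementType() is inferred as IntType.
        return self._elem


elem: IntType = OptionalType(IntType()).getElementType()
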
@@ -2681,17 +2674,18 @@ def _fuse_to_static_module(
 # Defined in torch/csrc/fx/node.cpp
 def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
 def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...
 
 class _NodeBase:
     _erased: _bool
     _prev: FxNode
     _next: FxNode
     def __init__(
         self,
         graph: Any,
         name: str,
         op: str,
         target: Any,
         return_type: Any,
     ) -> None: ...
     def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...

@@ -3,7 +3,7 @@
 import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Optional
+from typing import Callable
 
 class Aggregation(Enum):
     VALUE = ...
@@ -48,9 +48,9 @@ class _WaitCounterTracker:
     def __enter__(self) -> None: ...
     def __exit__(
         self,
-        exec_type: Optional[type[BaseException]] = None,
-        exec_value: Optional[BaseException] = None,
-        traceback: Optional[TracebackType] = None,
+        exc_type: type[BaseException] | None = None,
+        exc_value: BaseException | None = None,
+        traceback: TracebackType | None = None,
     ) -> None: ...
 
 class _WaitCounter:
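
Note on the Optional rewrites: `X | None` is the modern spelling of Optional[X]; it needs no typing import, and it is always acceptable inside stubs since stub files are never executed. In regular modules it runs natively on Python 3.10+, or on older versions behind a __future__ import. A short sketch (hypothetical function, illustrative only):

from __future__ import annotations  # lets `X | None` be used as an annotation on 3.9 too

from types import TracebackType


def describe(tb: TracebackType | None = None) -> str:
    # Equivalent to Optional[TracebackType], with no typing.Optional import.
    return "no traceback" if tb is None else f"raised at line {tb.tb_lineno}"
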

@@ -1,7 +1,8 @@
 # ${generated_comment}
 # mypy: disable-error-code="type-arg"
 
-from typing import Literal, overload, Sequence
+from collections.abc import Sequence
+from typing import Literal, overload
 
 from torch import memory_format, Tensor
 from torch.types import _bool, _device, _dtype, _int, _size

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Literal, Optional
+from typing import Literal
 from typing_extensions import TypeAlias
 
 from torch._C import device, dtype, layout
@@ -73,7 +73,7 @@ class ProfilerConfig:
         with_flops: bool,
         with_modules: bool,
         experimental_config: _ExperimentalConfig,
-        trace_id: Optional[str] = None,
+        trace_id: str | None = None,
     ) -> None: ...
 
 class _ProfilerEvent:
@@ -243,4 +243,4 @@ class _RecordFunctionFast:
         keyword_values: dict | None = None,
     ) -> None: ...
     def __enter__(self) -> None: ...
-    def __exit__(self, *args: Any) -> None: ...
+    def __exit__(self, *exc_info: object) -> None: ...

@@ -1,31 +1,11 @@
 # ${generated_comment}
 # mypy: allow-untyped-defs
 
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Final,
-    Iterator,
-    Literal,
-    NamedTuple,
-    NoReturn,
-    overload,
-    Sequence,
-    TypeVar,
-)
+from typing import Final, NoReturn
 from typing_extensions import Self
 
-from torch import (
-    contiguous_format,
-    Generator,
-    inf,
-    memory_format,
-    strided,
-    SymInt,
-    Tensor,
-)
-from torch.types import (
+from torch import SymInt, Tensor
+from torch.types import (  # noqa: F401
     _bool,
     _device,
     _dtype,

@@ -1,7 +1,8 @@
 # ${generated_comment}
 # mypy: allow-untyped-defs
 
-from typing import Any, Callable, Literal, overload, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, Literal, overload
 from typing_extensions import TypeAlias
 
 from torch import Tensor

@@ -2383,7 +2383,7 @@ class CudaNonDefaultStream:
                                      device_type=deviceStream.device_type)
         torch._C._cuda_setDevice(beforeDevice)
 
-    def __exit__(self, exec_type, exec_value, traceback):
+    def __exit__(self, exc_type, exc_value, traceback):
         # After completing CUDA test load previously active streams on all
         # CUDA devices.
         beforeDevice = torch.cuda.current_device()
@@ -2431,9 +2431,9 @@ class CudaMemoryLeakCheck:
             driver_mem_allocated = bytes_total - bytes_free
             self.driver_befores.append(driver_mem_allocated)
 
-    def __exit__(self, exec_type, exec_value, traceback):
+    def __exit__(self, exc_type, exc_value, traceback):
         # Don't check for leaks if an exception was thrown
-        if exec_type is not None:
+        if exc_type is not None:
             return
 
         # Compares caching allocator before/after statistics

@@ -5,19 +5,8 @@
 # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
 # classes/objects here, even though we are not injecting extra code into them at the moment.
 
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Type,
-    TypeVar,
-    Union,
-)
+from collections.abc import Iterable, Iterator
+from typing import Any, Callable, Literal, Optional, TypeVar, Union
 
 from torch.utils.data import Dataset, default_collate, IterableDataset
 from torch.utils.data.datapipes._hook_iterator import _SnapshotState
@@ -27,19 +16,19 @@ _T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
 UNTRACABLE_DATAFRAME_PIPES: Any
 
-class DataChunk(List[_T]):
-    items: List[_T]
+class DataChunk(list[_T]):
+    items: list[_T]
     def __init__(self, items: Iterable[_T]) -> None: ...
     def as_str(self, indent: str = "") -> str: ...
     def __iter__(self) -> Iterator[_T]: ...
     def raw_iterator(self) -> Iterator[_T]: ...
 
 class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
-    functions: Dict[str, Callable] = ...
-    reduce_ex_hook: Optional[Callable] = ...
-    getstate_hook: Optional[Callable] = ...
-    str_hook: Optional[Callable] = ...
-    repr_hook: Optional[Callable] = ...
+    functions: dict[str, Callable] = ...
+    reduce_ex_hook: Callable | None = ...
+    getstate_hook: Callable | None = ...
+    str_hook: Callable | None = ...
+    repr_hook: Callable | None = ...
     def __getattr__(self, attribute_name: Any): ...
     @classmethod
     def register_function(cls, function_name: Any, function: Any) -> None: ...
@@ -58,7 +47,7 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
     ${MapDataPipeMethods}
 
 class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
-    functions: Dict[str, Callable] = ...
+    functions: dict[str, Callable] = ...
     reduce_ex_hook: Optional[Callable] = ...
     getstate_hook: Optional[Callable] = ...
     str_hook: Optional[Callable] = ...
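
Note on the Dict/List rewrites: since Python 3.9 the builtin containers are subscriptable, so dict[str, Callable] and list[_T] replace the deprecated typing.Dict and typing.List, and DataChunk can subclass list[_T] directly. A brief sketch (hypothetical class, illustrative only; requires Python 3.9+):

from collections.abc import Iterable
from typing import TypeVar

_T = TypeVar("_T")


class Chunk(list[_T]):
    # Subclassing the subscripted builtin replaces typing.List[_T].
    def __init__(self, items: Iterable[_T]) -> None:
        super().__init__(items)

    def first(self) -> _T:
        return self[0]


print(Chunk([1, 2, 3]).first())  # prints 1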