[BE][CI][Easy] Run lintrunner on generated .pyi stub files (#150732)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150732
Approved by: https://github.com/malfet, https://github.com/cyyever, https://github.com/aorenste
This commit is contained in:
Xuehai Pan 2025-05-27 20:23:17 +08:00 committed by PyTorch MergeBot
parent 0a7eef140b
commit 7ae204c3b6
12 changed files with 95 additions and 128 deletions

View File

@ -31,6 +31,9 @@ python3 -m tools.pyi.gen_pyi \
--deprecated-functions-path "tools/autograd/deprecated.yaml" --deprecated-functions-path "tools/autograd/deprecated.yaml"
python3 torch/utils/data/datapipes/gen_pyi.py python3 torch/utils/data/datapipes/gen_pyi.py
# Also check generated pyi files
find torch -name '*.pyi' -exec git add --force -- "{}" +
RC=0 RC=0
# Run lintrunner on all files # Run lintrunner on all files
if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
@ -41,6 +44,9 @@ if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS}
RC=1 RC=1
fi fi
# Unstage temporally added pyi files
find torch -name '*.pyi' -exec git restore --staged -- "{}" +
# Use jq to massage the JSON lint output into GitHub Actions workflow commands. # Use jq to massage the JSON lint output into GitHub Actions workflow commands.
jq --raw-output \ jq --raw-output \
'"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \ '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \

View File

@ -212,6 +212,11 @@ select = [
"__init__.py" = [ "__init__.py" = [
"F401", "F401",
] ]
"*.pyi" = [
"PYI011", # typed-argument-default-in-stub
"PYI021", # docstring-in-stub
"PYI053", # string-or-bytes-too-long
]
"functorch/notebooks/**" = [ "functorch/notebooks/**" = [
"F401", "F401",
] ]

View File

@ -113,9 +113,9 @@ class MpsMemoryLeakCheck:
self.caching_allocator_before = torch.mps.current_allocated_memory() self.caching_allocator_before = torch.mps.current_allocated_memory()
self.driver_before = torch.mps.driver_allocated_memory() self.driver_before = torch.mps.driver_allocated_memory()
def __exit__(self, exec_type, exec_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
# Don't check for leaks if an exception was thrown # Don't check for leaks if an exception was thrown
if exec_type is not None: if exc_type is not None:
return return
# Compares caching allocator before/after statistics # Compares caching allocator before/after statistics
# An increase in allocated memory is a discrepancy indicating a possible memory leak # An increase in allocated memory is a discrepancy indicating a possible memory leak

View File

@ -1,20 +1,11 @@
# ${generated_comment} # ${generated_comment}
# mypy: disable-error-code="type-arg" # mypy: disable-error-code="type-arg"
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
# ruff: noqa: F401,PYI054
import builtins from collections.abc import Sequence
from types import EllipsisType from types import EllipsisType
from typing import ( from typing import Any, Callable, Literal, overload, TypeVar
Any,
Callable,
ContextManager,
Iterator,
Literal,
NamedTuple,
overload,
Sequence,
TypeVar,
)
import torch import torch
from torch import ( from torch import (

View File

@ -1,8 +1,9 @@
# ${generated_comment} # ${generated_comment}
# mypy: disable-error-code="type-arg" # mypy: disable-error-code="type-arg"
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
# ruff: noqa: F401
import builtins from collections.abc import Iterable, Iterator, Sequence
from enum import Enum, IntEnum from enum import Enum, IntEnum
from pathlib import Path from pathlib import Path
from types import EllipsisType from types import EllipsisType
@ -10,17 +11,13 @@ from typing import (
Any, Any,
AnyStr, AnyStr,
Callable, Callable,
ContextManager,
Generic, Generic,
IO, IO,
Iterable,
Iterator,
Literal, Literal,
NamedTuple, NamedTuple,
overload, overload,
Protocol, Protocol,
runtime_checkable, runtime_checkable,
Sequence,
SupportsIndex, SupportsIndex,
TypeVar, TypeVar,
) )
@ -71,15 +68,15 @@ from torch.utils._python_dispatch import TorchDispatchMode
# This module is defined in torch/csrc/Module.cpp # This module is defined in torch/csrc/Module.cpp
K = TypeVar("K") K = TypeVar("K") # noqa: PYI001
T = TypeVar("T") T = TypeVar("T") # noqa: PYI001
S = TypeVar("S", bound=torch.Tensor) S = TypeVar("S", bound=torch.Tensor) # noqa: PYI001
P = ParamSpec("P") P = ParamSpec("P") # noqa: PYI001
ReturnVal = TypeVar("ReturnVal", covariant=True) # return value (always covariant) R = TypeVar("R", covariant=True) # return value (always covariant) # noqa: PYI001
_T_co = TypeVar("_T_co", covariant=True) T_co = TypeVar("T_co", covariant=True) # noqa: PYI001
@runtime_checkable @runtime_checkable
class _NestedSequence(Protocol[_T_co]): class _NestedSequence(Protocol[T_co]):
"""A protocol for representing nested sequences. """A protocol for representing nested sequences.
References:: References::
@ -88,10 +85,10 @@ class _NestedSequence(Protocol[_T_co]):
""" """
def __len__(self, /) -> _int: ... def __len__(self, /) -> _int: ...
def __getitem__(self, index: _int, /) -> _T_co | _NestedSequence[_T_co]: ... def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ...
def __contains__(self, x: object, /) -> _bool: ... def __contains__(self, x: object, /) -> _bool: ...
def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
def count(self, value: Any, /) -> _int: ... def count(self, value: Any, /) -> _int: ...
def index(self, value: Any, /) -> _int: ... def index(self, value: Any, /) -> _int: ...
@ -146,7 +143,7 @@ class Stream:
def record_event(self, event: Event | None = None) -> Event: ... def record_event(self, event: Event | None = None) -> Event: ...
def __hash__(self) -> _int: ... def __hash__(self) -> _int: ...
def __eq__(self, other: object) -> _bool: ... def __eq__(self, other: object) -> _bool: ...
def __enter__(self) -> Stream: ... def __enter__(self) -> Self: ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
# Defined in torch/csrc/Event.cpp # Defined in torch/csrc/Event.cpp
@ -321,14 +318,14 @@ def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
def unify_type_list(types: list[JitType]) -> JitType: ... def unify_type_list(types: list[JitType]) -> JitType: ...
def _freeze_module( def _freeze_module(
module: ScriptModule, module: ScriptModule,
preserved_attrs: list[str] = [], preserved_attrs: list[str] = ...,
freeze_interfaces: _bool = True, freeze_interfaces: _bool = True,
preserveParameters: _bool = True, preserveParameters: _bool = True,
) -> ScriptModule: ... ) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ... def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference( def _jit_pass_optimize_for_inference(
module: torch.jit.ScriptModule, module: torch.jit.ScriptModule,
other_methods: list[str] = [], other_methods: list[str] = ...,
) -> None: ... ) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ... def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ... def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
@ -759,7 +756,7 @@ class AliasDb: ...
class _InsertPoint: class _InsertPoint:
def __enter__(self) -> None: ... def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
# Defined in torch/csrc/jit/ir/ir.h # Defined in torch/csrc/jit/ir/ir.h
class Use: class Use:
@ -1078,8 +1075,8 @@ class LiteScriptModule:
def run_method(self, method_name: str, *input): ... def run_method(self, method_name: str, *input): ...
# NOTE: switch to collections.abc.Callable in python 3.9 # NOTE: switch to collections.abc.Callable in python 3.9
class ScriptFunction(Generic[P, ReturnVal]): class ScriptFunction(Generic[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ... def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ...
def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ... def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ...
@property @property
@ -1092,9 +1089,9 @@ class ScriptFunction(Generic[P, ReturnVal]):
def qualified_name(self) -> str: ... def qualified_name(self) -> str: ...
# NOTE: switch to collections.abc.Callable in python 3.9 # NOTE: switch to collections.abc.Callable in python 3.9
class ScriptMethod(Generic[P, ReturnVal]): class ScriptMethod(Generic[P, R]):
graph: Graph graph: Graph
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
@property @property
def owner(self) -> ScriptModule: ... def owner(self) -> ScriptModule: ...
@property @property
@ -1481,9 +1478,7 @@ def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ... def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ... def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ... def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
def _pop_torch_dispatch_stack( def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ...
mode_key: _TorchDispatchModeKey | None = None,
) -> Any: ...
def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ... def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ...
def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ... def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ...
def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ... def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
@ -1494,42 +1489,42 @@ def _activate_gpu_trace() -> None: ...
class _DisableTorchDispatch: class _DisableTorchDispatch:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _EnableTorchFunction: class _EnableTorchFunction:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _EnablePythonDispatcher: class _EnablePythonDispatcher:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _DisablePythonDispatcher: class _DisablePythonDispatcher:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _EnablePreDispatch: class _EnablePreDispatch:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _DisableFuncTorch: class _DisableFuncTorch:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _DisableAutocast: class _DisableAutocast:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _InferenceMode: class _InferenceMode:
def __init__(self, enabled: _bool) -> None: ... def __init__(self, enabled: _bool) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
def _set_autograd_fallback_mode(mode: str) -> None: ... def _set_autograd_fallback_mode(mode: str) -> None: ...
def _get_autograd_fallback_mode() -> str: ... def _get_autograd_fallback_mode() -> str: ...
@ -1783,32 +1778,32 @@ def _commit_update(a: Tensor) -> None: ...
class _ExcludeDispatchKeyGuard: class _ExcludeDispatchKeyGuard:
def __init__(self, keyset: DispatchKeySet) -> None: ... def __init__(self, keyset: DispatchKeySet) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _IncludeDispatchKeyGuard: class _IncludeDispatchKeyGuard:
def __init__(self, k: DispatchKey) -> None: ... def __init__(self, k: DispatchKey) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _ForceDispatchKeyGuard: class _ForceDispatchKeyGuard:
def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ... def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _PreserveDispatchKeyGuard: class _PreserveDispatchKeyGuard:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _AutoDispatchBelowAutograd: class _AutoDispatchBelowAutograd:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
class _AutoDispatchBelowADInplaceOrView: class _AutoDispatchBelowADInplaceOrView:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ... def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key( def _dispatch_get_registrations_for_dispatch_key(
@ -1827,18 +1822,16 @@ class _TorchDispatchModeKey(Enum):
class _SetExcludeDispatchKeyGuard: class _SetExcludeDispatchKeyGuard:
def __init__(self, k: DispatchKey, enabled: _bool) -> None: ... def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
def __enter__(self): ... def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...
# Defined in torch/csrc/utils/schema_info.h # Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo: class _SchemaInfo:
def __init__(self, schema: _int) -> None: ... def __init__(self, schema: _int) -> None: ...
@overload @overload
def is_mutable(self) -> _bool: ... def is_mutable(self) -> _bool: ...
@overload @overload
def is_mutable(self, name: str) -> _bool: ... def is_mutable(self, name: str) -> _bool: ...
def has_argument(self, name: str) -> _bool: ... def has_argument(self, name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp # Defined in torch/csrc/utils/init.cpp
@ -2431,7 +2424,7 @@ def _create_graph_by_tracing(
strict: Any, strict: Any,
force_outplace: Any, force_outplace: Any,
self: Any = None, self: Any = None,
argument_names: list[str] = [], argument_names: list[str] = ...,
) -> tuple[Graph, Stack]: ... ) -> tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ... def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ... def _get_tracing_state() -> TracingState: ...
@ -2458,8 +2451,6 @@ class InferredType:
def success(self) -> _bool: ... def success(self) -> _bool: ...
def reason(self) -> str: ... def reason(self) -> str: ...
R = TypeVar("R", bound=JitType)
class Type(JitType): class Type(JitType):
def str(self) -> _str: ... def str(self) -> _str: ...
def containedTypes(self) -> list[JitType]: ... def containedTypes(self) -> list[JitType]: ...
@ -2565,9 +2556,11 @@ class InterfaceType(JitType):
def getMethod(self, name: str) -> FunctionSchema | None: ... def getMethod(self, name: str) -> FunctionSchema | None: ...
def getMethodNames(self) -> list[str]: ... def getMethodNames(self) -> list[str]: ...
class OptionalType(JitType, Generic[R]): JitTypeT = TypeVar("JitTypeT", bound=JitType) # noqa: PYI001
def __init__(self, a: JitType) -> None: ...
def getElementType(self) -> JitType: ... class OptionalType(JitType, Generic[JitTypeT]):
def __init__(self, a: JitTypeT) -> None: ...
def getElementType(self) -> JitTypeT: ...
@staticmethod @staticmethod
def ofTensor() -> OptionalType: ... def ofTensor() -> OptionalType: ...
@ -2681,6 +2674,7 @@ def _fuse_to_static_module(
# Defined in torch/csrc/fx/node.cpp # Defined in torch/csrc/fx/node.cpp
def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ... def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ... def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...
class _NodeBase: class _NodeBase:
_erased: _bool _erased: _bool
_prev: FxNode _prev: FxNode

View File

@ -3,7 +3,7 @@
import datetime import datetime
from enum import Enum from enum import Enum
from types import TracebackType from types import TracebackType
from typing import Callable, Optional from typing import Callable
class Aggregation(Enum): class Aggregation(Enum):
VALUE = ... VALUE = ...
@ -48,9 +48,9 @@ class _WaitCounterTracker:
def __enter__(self) -> None: ... def __enter__(self) -> None: ...
def __exit__( def __exit__(
self, self,
exec_type: Optional[type[BaseException]] = None, exc_type: type[BaseException] | None = None,
exec_value: Optional[BaseException] = None, exc_value: BaseException | None = None,
traceback: Optional[TracebackType] = None, traceback: TracebackType | None = None,
) -> None: ... ) -> None: ...
class _WaitCounter: class _WaitCounter:

View File

@ -1,7 +1,8 @@
# ${generated_comment} # ${generated_comment}
# mypy: disable-error-code="type-arg" # mypy: disable-error-code="type-arg"
from typing import Literal, overload, Sequence from collections.abc import Sequence
from typing import Literal, overload
from torch import memory_format, Tensor from torch import memory_format, Tensor
from torch.types import _bool, _device, _dtype, _int, _size from torch.types import _bool, _device, _dtype, _int, _size

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, Literal, Optional from typing import Literal
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from torch._C import device, dtype, layout from torch._C import device, dtype, layout
@ -73,7 +73,7 @@ class ProfilerConfig:
with_flops: bool, with_flops: bool,
with_modules: bool, with_modules: bool,
experimental_config: _ExperimentalConfig, experimental_config: _ExperimentalConfig,
trace_id: Optional[str] = None, trace_id: str | None = None,
) -> None: ... ) -> None: ...
class _ProfilerEvent: class _ProfilerEvent:
@ -243,4 +243,4 @@ class _RecordFunctionFast:
keyword_values: dict | None = None, keyword_values: dict | None = None,
) -> None: ... ) -> None: ...
def __enter__(self) -> None: ... def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ... def __exit__(self, *exc_info: object) -> None: ...

View File

@ -1,31 +1,11 @@
# ${generated_comment} # ${generated_comment}
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import ( from typing import Final, NoReturn
Any,
Callable,
ContextManager,
Final,
Iterator,
Literal,
NamedTuple,
NoReturn,
overload,
Sequence,
TypeVar,
)
from typing_extensions import Self from typing_extensions import Self
from torch import ( from torch import SymInt, Tensor
contiguous_format, from torch.types import ( # noqa: F401
Generator,
inf,
memory_format,
strided,
SymInt,
Tensor,
)
from torch.types import (
_bool, _bool,
_device, _device,
_dtype, _dtype,

View File

@ -1,7 +1,8 @@
# ${generated_comment} # ${generated_comment}
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import Any, Callable, Literal, overload, Sequence from collections.abc import Sequence
from typing import Any, Callable, Literal, overload
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from torch import Tensor from torch import Tensor

View File

@ -2383,7 +2383,7 @@ class CudaNonDefaultStream:
device_type=deviceStream.device_type) device_type=deviceStream.device_type)
torch._C._cuda_setDevice(beforeDevice) torch._C._cuda_setDevice(beforeDevice)
def __exit__(self, exec_type, exec_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
# After completing CUDA test load previously active streams on all # After completing CUDA test load previously active streams on all
# CUDA devices. # CUDA devices.
beforeDevice = torch.cuda.current_device() beforeDevice = torch.cuda.current_device()
@ -2431,9 +2431,9 @@ class CudaMemoryLeakCheck:
driver_mem_allocated = bytes_total - bytes_free driver_mem_allocated = bytes_total - bytes_free
self.driver_befores.append(driver_mem_allocated) self.driver_befores.append(driver_mem_allocated)
def __exit__(self, exec_type, exec_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
# Don't check for leaks if an exception was thrown # Don't check for leaks if an exception was thrown
if exec_type is not None: if exc_type is not None:
return return
# Compares caching allocator before/after statistics # Compares caching allocator before/after statistics

View File

@ -5,19 +5,8 @@
# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
# classes/objects here, even though we are not injecting extra code into them at the moment. # classes/objects here, even though we are not injecting extra code into them at the moment.
from typing import ( from collections.abc import Iterable, Iterator
Any, from typing import Any, Callable, Literal, Optional, TypeVar, Union
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from torch.utils.data import Dataset, default_collate, IterableDataset from torch.utils.data import Dataset, default_collate, IterableDataset
from torch.utils.data.datapipes._hook_iterator import _SnapshotState from torch.utils.data.datapipes._hook_iterator import _SnapshotState
@ -27,19 +16,19 @@ _T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True) _T_co = TypeVar("_T_co", covariant=True)
UNTRACABLE_DATAFRAME_PIPES: Any UNTRACABLE_DATAFRAME_PIPES: Any
class DataChunk(List[_T]): class DataChunk(list[_T]):
items: List[_T] items: list[_T]
def __init__(self, items: Iterable[_T]) -> None: ... def __init__(self, items: Iterable[_T]) -> None: ...
def as_str(self, indent: str = "") -> str: ... def as_str(self, indent: str = "") -> str: ...
def __iter__(self) -> Iterator[_T]: ... def __iter__(self) -> Iterator[_T]: ...
def raw_iterator(self) -> Iterator[_T]: ... def raw_iterator(self) -> Iterator[_T]: ...
class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
functions: Dict[str, Callable] = ... functions: dict[str, Callable] = ...
reduce_ex_hook: Optional[Callable] = ... reduce_ex_hook: Callable | None = ...
getstate_hook: Optional[Callable] = ... getstate_hook: Callable | None = ...
str_hook: Optional[Callable] = ... str_hook: Callable | None = ...
repr_hook: Optional[Callable] = ... repr_hook: Callable | None = ...
def __getattr__(self, attribute_name: Any): ... def __getattr__(self, attribute_name: Any): ...
@classmethod @classmethod
def register_function(cls, function_name: Any, function: Any) -> None: ... def register_function(cls, function_name: Any, function: Any) -> None: ...
@ -58,7 +47,7 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
${MapDataPipeMethods} ${MapDataPipeMethods}
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
functions: Dict[str, Callable] = ... functions: dict[str, Callable] = ...
reduce_ex_hook: Optional[Callable] = ... reduce_ex_hook: Optional[Callable] = ...
getstate_hook: Optional[Callable] = ... getstate_hook: Optional[Callable] = ...
str_hook: Optional[Callable] = ... str_hook: Optional[Callable] = ...