diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 69d90d4e7a1..a5c4d390ee3 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -40,6 +40,7 @@ from torch._C import (
 )
 from torch._prims_common import DeviceLikeType
 from torch.autograd.graph import Node as _Node
+from torch.cuda import _POOL_HANDLE
 from torch.fx.node import Node as FxNode
 from torch.package import PackageExporter
 from torch.storage import TypedStorage, UntypedStorage
@@ -2289,7 +2290,7 @@ class _CUDAGraph:
     def __new__(cls, keep_graph: _bool = ...) -> Self: ...
     def capture_begin(
         self,
-        pool: tuple[_int, _int] | None = ...,
+        pool: _POOL_HANDLE | None = ...,
         capture_error_mode: str = "global",
     ) -> None: ...
     def capture_end(self) -> None: ...
@@ -2297,7 +2298,7 @@ class _CUDAGraph:
     def register_generator_state(self, Generator) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...
-    def pool(self) -> tuple[_int, _int]: ...
+    def pool(self) -> _POOL_HANDLE: ...
     def enable_debug_mode(self) -> None: ...
     def debug_dump(self, debug_path: str) -> None: ...
     def raw_cuda_graph(self) -> _int: ...
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index bdc201803fb..3b3dea909cd 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -90,6 +90,7 @@ if TYPE_CHECKING:
 
     from torch._guards import CompileId
     from torch._inductor.utils import InputType
+    from torch.cuda import _POOL_HANDLE
     from torch.types import _bool
 
 StorageWeakRefPointer = int
@@ -817,7 +818,7 @@ class CUDAGraphNode:
         id: GraphID,
         parent: Optional[CUDAGraphNode],
         inputs: list[InputType],
-        cuda_graphs_pool: tuple[int, int],
+        cuda_graphs_pool: _POOL_HANDLE,
         device_index: int,
         stack_traces: Optional[StackTraces],
         stream: torch.cuda.Stream,
@@ -1228,6 +1229,7 @@ class CUDAGraphNode:
 
     def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
         "Record the model"
+        assert self.graph is not None
 
         def static_input_iter() -> Generator[torch.Tensor, None, None]:
             for i in self.wrapped_function.static_input_idxs:
@@ -1310,13 +1312,11 @@ class CUDAGraphNode:
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue
 
-            (
-                torch._check(
-                    o.is_cuda or o.untyped_storage().data_ptr() == 0,
-                    lambda: (
-                        "Expected all cuda outputs in cuda graph recording. Non cuda output "
-                        f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
-                    ),
+            torch._check(
+                o.is_cuda or o.untyped_storage().data_ptr() == 0,
+                lambda: (
+                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
+                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
                 ),
             )
 
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index b8036a5235b..ac3aad9f93b 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -4,8 +4,8 @@ import inspect
 import itertools
 import warnings
 from collections import OrderedDict
-from typing import Any, Optional
-from typing_extensions import deprecated
+from typing import Any, Callable, Optional, TypeVar
+from typing_extensions import Concatenate, deprecated, ParamSpec
 
 import torch
 import torch._C as _C
@@ -29,6 +29,10 @@ __all__ = [
 # This is incremented in FunctionMeta during class definition
 AUTOGRAD_FUNCTION_COUNTER = itertools.count()
 
+_T = TypeVar("_T")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
 
 # Formerly known as: _ContextMethodMixin
 class FunctionCtx:
@@ -595,11 +599,13 @@ def _is_setup_context_defined(fn):
     return fn != _SingleLevelFunction.setup_context
 
 
-def once_differentiable(fn):
+def once_differentiable(
+    fn: Callable[Concatenate[_T, _P], _R],
+) -> Callable[Concatenate[_T, _P], _R]:
     @functools.wraps(fn)
-    def wrapper(ctx, *args):
+    def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
-            outputs = fn(ctx, *args)
+            outputs = fn(ctx, *args, **kwargs)
 
         if not torch.is_grad_enabled():
             return outputs
@@ -620,12 +626,14 @@ def once_differentiable(fn):
             return outputs
 
         if not isinstance(outputs, tuple):
-            outputs = (outputs,)
+            outputs_ = (outputs,)
+        else:
+            outputs_ = outputs
 
         err_fn = _functions.DelayedError(
             b"trying to differentiate twice a function that was marked "
             b"with @once_differentiable",
-            len(outputs),
+            len(outputs_),
         )
 
         # Create aliases of each output that has requires_grad=True. We need
@@ -637,7 +645,7 @@ def once_differentiable(fn):
                 var.requires_grad = True
             return var
 
-        return err_fn(*[fake_requires_grad(v) for v in outputs])
+        return err_fn(*[fake_requires_grad(v) for v in outputs_])  # type: ignore[return-value]
 
     return wrapper
 
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index fc9d09ce63a..5b85c91d2c2 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -18,7 +18,7 @@ import threading
 import traceback
 import warnings
 from functools import lru_cache
-from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._C
@@ -1777,6 +1777,9 @@ def _compile_kernel(
 
 from . import amp, jiterator, nvtx, profiler, sparse, tunable
 
+_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int])
+
+
 __all__ = [
     # Typed storage and tensors
     "BFloat16Storage",
diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py
index b58a7808593..b1d1e4f8c47 100644
--- a/torch/cuda/graphs.py
+++ b/torch/cuda/graphs.py
@@ -1,12 +1,34 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import gc
 import typing
+from typing import Callable, Optional, overload, TYPE_CHECKING, Union
+from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar
 
 import torch
+from torch import Tensor
+
+
+if TYPE_CHECKING:
+    # importing _POOL_HANDLE at runtime toplevel causes an import cycle
+    from torch.cuda import _POOL_HANDLE
 
 from .._utils import _dummy_type
 
 
+__all__ = [
+    "is_current_stream_capturing",
+    "graph_pool_handle",
+    "CUDAGraph",
+    "graph",
+    "make_graphed_callables",
+]
+
+
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
+
+
 if not hasattr(torch._C, "_CudaStreamBase"):
     # Define dummy base classes
     torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
@@ -22,7 +44,7 @@ from torch._C import (  # noqa: F401
 )
 
 
-def is_current_stream_capturing():
+def is_current_stream_capturing() -> bool:
     r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
 
     If a CUDA context does not exist on the current device, returns False without initializing the context.
@@ -31,7 +53,7 @@ def is_current_stream_capturing():
 
 
 # Python shim helps Sphinx process docstrings more reliably.
-def graph_pool_handle():
+def graph_pool_handle() -> _POOL_HANDLE:
     r"""Return an opaque token representing the id of a graph memory pool.
 
     See :ref:`Graph memory management`.
@@ -39,7 +61,7 @@ def graph_pool_handle():
     .. warning::
         This API is in beta and may change in future releases.
     """
-    return _graph_pool_handle()
+    return torch.cuda._POOL_HANDLE(_graph_pool_handle())
 
 
 # Python shim helps Sphinx process docstrings more reliably.
@@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph):
 
     """
 
-    def __new__(cls, keep_graph=False):
+    def __new__(cls, keep_graph: bool = False) -> Self:
         return super().__new__(cls, keep_graph)
 
-    def capture_begin(self, pool=None, capture_error_mode="global"):
+    def capture_begin(
+        self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global"
+    ) -> None:
         r"""Begin capturing CUDA work on the current stream.
 
         Typically, you shouldn't call ``capture_begin`` yourself.
@@ -92,7 +116,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """  # noqa: B950
         super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
 
-    def capture_end(self):
+    def capture_end(self) -> None:
         r"""End CUDA graph capture on the current stream.
 
         After ``capture_end``, ``replay`` may be called on this instance.
@@ -103,7 +127,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().capture_end()
 
-    def instantiate(self):
+    def instantiate(self) -> None:
         r"""Instantiate the CUDA graph. Will be called by
         ``capture_end`` if ``keep_graph=False``, or by ``replay`` if
         ``keep_graph=True`` and ``instantiate`` has not already been
@@ -112,15 +136,15 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         super().instantiate()
 
-    def replay(self):
+    def replay(self) -> None:
         r"""Replay the CUDA work captured by this graph."""
         super().replay()
 
-    def reset(self):
+    def reset(self) -> None:
         r"""Delete the graph currently held by this instance."""
         super().reset()
 
-    def pool(self):
+    def pool(self) -> _POOL_HANDLE:
         r"""Return an opaque token representing the id of this graph's memory pool.
 
         This id can optionally be passed to another graph's ``capture_begin``,
@@ -128,11 +152,11 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().pool()
 
-    def enable_debug_mode(self):
+    def enable_debug_mode(self) -> None:
         r"""Enable debugging mode for CUDAGraph.debug_dump."""
         return super().enable_debug_mode()
 
-    def debug_dump(self, debug_path):
+    def debug_dump(self, debug_path: str) -> None:
         r"""
         Arguments:
             debug_path (required): Path to dump the graph to.
@@ -142,7 +166,7 @@ class CUDAGraph(torch._C._CUDAGraph):
         """
         return super().debug_dump(debug_path)
 
-    def raw_cuda_graph(self):
+    def raw_cuda_graph(self) -> int:
         r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True.
 
         See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_
@@ -180,13 +204,13 @@ class graph:
         https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
     """  # noqa: B950
 
-    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
+    default_capture_stream: Optional[torch.cuda.Stream] = None
 
     def __init__(
         self,
-        cuda_graph,
-        pool=None,
-        stream=None,
+        cuda_graph: CUDAGraph,
+        pool: Optional[_POOL_HANDLE] = None,
+        stream: Optional[torch.cuda.Stream] = None,
         capture_error_mode: str = "global",
     ):
         # Lazy-init of default_capture_stream helps avoid circular-import errors.
@@ -195,7 +219,9 @@ class graph:
         if self.__class__.default_capture_stream is None:
             self.__class__.default_capture_stream = torch.cuda.Stream()
 
-        self.pool = () if pool is None else (pool,)
+        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
+            () if pool is None else (pool,)
+        )
         self.capture_stream = (
             stream if stream is not None else self.__class__.default_capture_stream
         )
@@ -204,7 +230,7 @@ class graph:
         self.cuda_graph = cuda_graph
         self.capture_error_mode = capture_error_mode
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         # Free as much memory as we can for the graph
         torch.cuda.synchronize()
         gc.collect()
@@ -215,18 +241,47 @@ class graph:
         self.stream_ctx.__enter__()
 
         self.cuda_graph.capture_begin(
-            *self.pool, capture_error_mode=self.capture_error_mode
+            # type: ignore[misc]
+            *self.pool,
+            capture_error_mode=self.capture_error_mode,
         )
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, *args: object) -> None:
         self.cuda_graph.capture_end()
-        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
+        self.stream_ctx.__exit__(*args)
         # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
 
 
+_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
+
+
+@overload
 def make_graphed_callables(
-    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
-):
+    callables: _ModuleOrCallable,
+    sample_args: tuple[Tensor, ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> _ModuleOrCallable: ...
+
+
+@overload
+def make_graphed_callables(
+    callables: tuple[_ModuleOrCallable, ...],
+    sample_args: tuple[tuple[Tensor, ...], ...],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> tuple[_ModuleOrCallable, ...]: ...
+
+
+def make_graphed_callables(
+    callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]],
+    sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]],
+    num_warmup_iters: int = 3,
+    allow_unused_input: bool = False,
+    pool: Optional[_POOL_HANDLE] = None,
+) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]:
     r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions.
 
     Each graphed callable's forward pass runs its source callable's
@@ -300,14 +355,17 @@ def make_graphed_callables(
 
     just_one_callable = False
 
+    _sample_args: tuple[tuple[Tensor, ...], ...]
     if not isinstance(callables, tuple):
         just_one_callable = True
         callables = (callables,)
-        sample_args = (sample_args,)
+        _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),)
+    else:
+        _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args)
 
     flatten_sample_args = []
-    for c, args in zip(callables, sample_args):
+    for c, args in zip(callables, _sample_args):
         if isinstance(c, torch.nn.Module):
             assert (
                 len(c._backward_hooks) == 0
@@ -352,7 +410,7 @@ def make_graphed_callables(
     torch.cuda.synchronize()
     with torch.cuda.stream(torch.cuda.Stream()):
         for func, args, static_input_surface in zip(
-            callables, sample_args, per_callable_static_input_surfaces
+            callables, _sample_args, per_callable_static_input_surfaces
        ):
             grad_inputs, outputs, outputs_grad = None, None, None
             for _ in range(num_warmup_iters):
@@ -382,11 +440,11 @@ def make_graphed_callables(
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_unflatten_spec = []
-    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
+    for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs):
         with torch.cuda.graph(fwd_graph, pool=mempool):
-            outputs = func(*args)
+            func_outputs = func(*args)
 
-        flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
+        flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs)
         per_callable_static_outputs.append(tuple(flatten_outputs))
         per_callable_output_unflatten_spec.append(spec)
 
@@ -438,19 +496,19 @@ def make_graphed_callables(
 
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
     def make_graphed_autograd_function(
-        fwd_graph,
-        bwd_graph,
-        module_params,
-        len_user_args,
-        output_unflatten_spec,
-        static_input_surface,
-        static_outputs,
-        static_grad_outputs,
-        static_grad_inputs,
-    ):
+        fwd_graph: CUDAGraph,
+        bwd_graph: CUDAGraph,
+        module_params: tuple[torch.nn.Parameter, ...],
+        len_user_args: int,
+        output_unflatten_spec: torch.utils._pytree.TreeSpec,
+        static_input_surface: tuple[Tensor, ...],
+        static_outputs: tuple[Tensor, ...],
+        static_grad_outputs: tuple[Optional[Tensor], ...],
+        static_grad_inputs: tuple[Tensor, ...],
+    ) -> Callable[..., object]:
         class Graphed(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, *inputs):
+            def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
                 # At this stage, only the user args may (potentially) be new tensors.
                 for i in range(len_user_args):
                     if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
@@ -461,7 +519,7 @@ def make_graphed_callables(
 
             @staticmethod
             @torch.autograd.function.once_differentiable
-            def backward(ctx, *grads):
+            def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
                 assert len(grads) == len(static_grad_outputs)
                 for g, grad in zip(static_grad_outputs, grads):
                     if g is not None:
@@ -477,7 +535,7 @@ def make_graphed_callables(
                     b.detach() if b is not None else b for b in static_grad_inputs
                 )
 
-        def functionalized(*user_args):
+        def functionalized(*user_args: object) -> object:
             # Runs the autograd function with inputs == all inputs to the graph that might require grad
             # (explicit user args + module parameters)
             # Assumes module params didn't change since capture.
@@ -488,7 +546,7 @@ def make_graphed_callables(
         return functionalized
 
     # Put together the final graphed callables
-    ret = []
+    ret: list[_ModuleOrCallable] = []
     for i, func in enumerate(callables):
         graphed = make_graphed_autograd_function(
             fwd_graphs[i],
@@ -504,20 +562,25 @@ def make_graphed_callables(
 
         if isinstance(func, torch.nn.Module):
 
-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
-                def new_fwd(*user_args):
+            def make_graphed_forward(
+                func: torch.nn.Module,
+                graph_training_state: bool,
+                graphed: Callable[_P, _R],
+                orig_fwd: Callable[_P, _R],
+            ) -> Callable[_P, _R]:
+                def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R:
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
                     if func.training == graph_training_state:
-                        return graphed(*user_args)
+                        return graphed(*user_args, **user_kwargs)
                     else:
-                        return orig_fwd(*user_args)
+                        return orig_fwd(*user_args, **user_kwargs)
 
                 return new_fwd
 
             func.forward = make_graphed_forward(
                 func, func.training, graphed, func.forward
-            )  # type: ignore[assignment]
+            )
             ret.append(func)
         else:
             ret.append(graphed)
diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py
index 3557ef09930..7cde0bd3517 100644
--- a/torch/onnx/_internal/_lazy_import.py
+++ b/torch/onnx/_internal/_lazy_import.py
@@ -28,7 +28,7 @@ class _LazyModule:
 # NOTE: Add additional used imports here.
 if TYPE_CHECKING:
     import onnx
-    import onnx_ir  # type: ignore[import-untyped]
+    import onnx_ir  # type: ignore[import-untyped, import-not-found]
     import onnxscript
     import onnxscript._framework_apis.torch_2_8 as onnxscript_apis
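
Illustrative usage sketch, not part of the patch above: the diff replaces the raw tuple[int, int] pool id with a NewType wrapper, so torch.cuda.graph_pool_handle() and CUDAGraph.pool() now type-check as _POOL_HANDLE wherever a pool argument is expected, while remaining a plain tuple at runtime. The snippet below is a minimal sketch of how the typed handle flows through the public API; variable names are illustrative, it assumes a CUDA device is available, and it skips the warmup pass normally recommended before capture.

# Sketch only (not from the diff): sharing a memory pool between two captured graphs
# using the NewType-based pool handle.
import torch

g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

# Opaque pool token; statically a _POOL_HANDLE, at runtime still a tuple of ints.
pool = torch.cuda.graph_pool_handle()

static_x = torch.zeros(8, device="cuda")
with torch.cuda.graph(g1, pool=pool):
    static_y = static_x + 1

# A second graph may share the first graph's pool; g1.pool() is also typed _POOL_HANDLE.
with torch.cuda.graph(g2, pool=g1.pool()):
    static_z = static_y * 2

g1.replay()
g2.replay()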