diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py
index b2d78497525..f8599d39383 100644
--- a/torch/_dynamo/backends/cudagraphs.py
+++ b/torch/_dynamo/backends/cudagraphs.py
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements CUDA graphs support for TorchDynamo backends.
 
@@ -25,9 +23,11 @@ Key components:
 import functools
 from collections import defaultdict
-from typing import Optional
+from collections.abc import Sequence
+from typing import Any, Callable, Optional
 
 import torch
+import torch.fx
 from torch._dynamo import config
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.backends.debugging import boxed_nop
@@ -51,8 +51,8 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from .registry import register_backend
 
 
-def find_input_mutations(g):
-    def meta_fk(meta):
+def find_input_mutations(g: torch.fx.Graph) -> set[int]:
+    def meta_fk(meta: dict[str, Any]) -> Any:
         return meta["val"] if "val" in meta else meta["fake_result"]
 
     inputs = defaultdict(set)
@@ -90,7 +90,9 @@ def find_input_mutations(g):
     return mutated_inputs
 
 
-def get_device_node_mapping(gm: torch.fx.GraphModule):
+def get_device_node_mapping(
+    gm: torch.fx.GraphModule,
+) -> dict[torch.device, torch.fx.Node]:
     device_node_mapping: dict[torch.device, torch.fx.Node] = {}
     for n in gm.graph.nodes:
         t = n.meta.get("val", None)
@@ -100,7 +102,7 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):
 
 
 def check_for_mutation_ignore_cuda_graph_managed_tensor(
-    aot_model: torch.fx.GraphModule, num_fixed
+    aot_model: torch.fx.GraphModule, num_fixed: int
 ) -> Optional[str]:
     mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
     if not mutation_indices:
@@ -110,7 +112,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor(
     return get_mutation_stack_trace(placeholders, mutation_indices)
 
 
-def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
     if not config.cudagraph_backend_support_input_mutation:
         if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
             aot_model, num_fixed
@@ -128,28 +130,35 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
     return None
 
 
-def get_device_index(gm) -> int:
+def get_device_index(gm: torch.fx.GraphModule) -> int:
     device = next(iter(get_device_node_mapping(gm)))
     assert device.type == "cuda"
     return device.index
 
 
-def get_stack_traces(gm) -> list[Optional[str]]:
+def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
     output = output_node(gm)
     assert len(output.args) == 1
+    args = output.args[0]
+    if not hasattr(args, "__iter__"):
+        return []
     return [
         (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-        for arg in output.args[0]
+        for arg in args  # type: ignore[union-attr]
     ]
 
 
-def cudagraphs(dynamo_model, dynamo_inputs):
+def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
     from torch._inductor.cudagraph_trees import cudagraphify_impl
 
     do_cudagraphs = BoxedBool(True)
     boxed_device_index = BoxedDeviceIndex(None)
 
-    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
+    def forward_cudagraphs(
+        aot_model: torch.fx.GraphModule,
+        aot_inputs: list[Any],
+        is_inference: bool = False,
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
         if skip_msg := check_for_skip(aot_model, fixed):
@@ -166,15 +175,17 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             range(fixed),
             device_index=boxed_device_index.value,
             is_backward=False,
-            is_inference=False,
+            is_inference=False,  # Q: should we forward is_inference here?
             stack_traces=get_stack_traces(aot_model),
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out
 
-    def backward_cudagraphs(aot_model, aot_inputs):
+    def backward_cudagraphs(
+        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         if not do_cudagraphs:
             return aot_model
@@ -182,20 +193,23 @@ def cudagraphs(dynamo_model, dynamo_inputs):
         fixed = count_tangents(aot_model)
         if skip_msg := check_for_skip(aot_model, fixed):
             log_cudagraph_skip_and_bump_counter(
-                "skipping cudagraphs due to %s", skip_msg
+                f"skipping cudagraphs due to {skip_msg}"
             )
 
             # See [Backward Generation Handling]
+            device_idx = boxed_device_index.value
+            if device_idx is None:
+                device_idx = 0  # Default to device 0 if not set
             manager = torch._inductor.cudagraph_trees.get_manager(
-                boxed_device_index.value, create_if_none_exists=False
+                device_idx, create_if_none_exists=False
             )
             assert manager is not None
 
-            def fn(inputs):
+            def fn(inputs: list[Any]) -> Any:
                 manager.set_to_running_backward()
                 return aot_model(inputs)
 
-            fn._boxed_call = True
+            fn._boxed_call = True  # type: ignore[attr-defined]
             return fn
 
         out = cudagraphify_impl(
@@ -209,7 +223,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out
 
     aot_cudagraphs = aot_autograd(
@@ -225,13 +239,13 @@ class CudagraphsBackend:
     compiler_name = "cudagraphs"
 
     @staticmethod
-    def reset():
+    def reset() -> None:
         from torch._inductor.cudagraph_trees import reset_cudagraph_trees
 
         reset_cudagraph_trees()
 
     @staticmethod
-    def __call__(model, inputs):
+    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
         return cudagraphs(model, inputs)
 
 
@@ -240,7 +254,12 @@ class CudagraphsBackend:
 register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
 
 
-def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
+def cudagraphs_inner(
+    model: Callable[..., Any],
+    inputs: Sequence[Any],
+    copy_outputs: bool = True,
+    copy_inputs: bool = True,
+) -> Callable[..., Sequence[Any]]:
     """This isn't registered as a backend, but is used in some benchmarks"""
     assert isinstance(inputs, (list, tuple))
     if copy_inputs:
@@ -265,7 +284,7 @@ def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)
 
-    def run(*new_inputs):
+    def run(*new_inputs: Any) -> Sequence[Any]:
         assert len(static_inputs) == len(new_inputs)
         if copy_inputs:
             for dst, src in zip(static_inputs, new_inputs):
diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py
index 7826c797d36..e6281ad30e4 100644
--- a/torch/_inductor/cudagraph_utils.py
+++ b/torch/_inductor/cudagraph_utils.py
@@ -14,7 +14,7 @@ from .utils import is_using_cudagraph_partition
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Sequence, Set as AbstractSet
 
 
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -110,7 +110,8 @@ def format_default_skip_message(reason: str) -> str:
 
 
 def get_mutation_stack_trace(
-    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
+    placeholders: Sequence[PlaceholderInfo],
+    mutation_indices: Union[AbstractSet[int], Sequence[int]],
 ) -> str:
     stack_trace: Optional[str] = ""
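
For context, a minimal smoke test of the backend this patch annotates. The toy function, input shape, and CUDA guard are illustrative assumptions, not part of the patch; `torch.compile(backend="cudagraphs")` routes through the `CudagraphsBackend` registered above.

```python
import torch

# Minimal sketch exercising the "cudagraphs" backend registered above.
# The function and shapes are made up for illustration; CUDA graph
# capture requires an actual CUDA device, hence the guard.
@torch.compile(backend="cudagraphs")
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

if torch.cuda.is_available():
    x = torch.randn(8, 8, device="cuda")
    out = f(x)  # early calls warm up/record; later calls may replay a captured graph
    print(out.shape)  # torch.Size([8, 8])
```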
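
Similarly, a hedged sketch of driving the benchmark helper `cudagraphs_inner`, going only by the signature annotated in this patch; the toy model and input are assumptions.

```python
import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner

# Toy model for illustration only. cudagraphs_inner records model(*inputs)
# into a CUDA graph and returns a callable that replays it on new inputs.
def model(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) * 2

if torch.cuda.is_available():
    example_inputs = [torch.randn(16, device="cuda")]
    run = cudagraphs_inner(model, example_inputs)  # copy_outputs/copy_inputs default to True
    (result,) = run(*example_inputs)  # outputs come back as a Sequence, per the new annotation
```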