Type cudagraphs.py (#160363)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160363
Approved by: https://github.com/StrongerXi
ghstack dependencies: #160362
Lucas Kabela 2025-08-13 15:40:09 -07:00 committed by PyTorch MergeBot
parent f82c7eed84
commit 6fe6dd9fdc
2 changed files with 47 additions and 27 deletions
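Note for readers: this file implements the eager-mode "cudagraphs" compile backend, registered by name at the bottom of the diff. A minimal usage sketch (toy model and inputs are illustrative; assumes a CUDA device):

import torch

model = torch.nn.Linear(8, 8).cuda()
compiled = torch.compile(model, backend="cudagraphs")  # selects CudagraphsBackend
out = compiled(torch.randn(4, 8, device="cuda"))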

torch/_dynamo/backends/cudagraphs.py

@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements CUDA graphs support for TorchDynamo backends.
@@ -25,9 +23,11 @@ Key components:
 import functools
 from collections import defaultdict
-from typing import Optional
+from collections.abc import Sequence
+from typing import Any, Callable, Optional

 import torch
+import torch.fx
 from torch._dynamo import config
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.backends.debugging import boxed_nop
@@ -51,8 +51,8 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from .registry import register_backend


-def find_input_mutations(g):
-    def meta_fk(meta):
+def find_input_mutations(g: torch.fx.Graph) -> set[int]:
+    def meta_fk(meta: dict[str, Any]) -> Any:
         return meta["val"] if "val" in meta else meta["fake_result"]

     inputs = defaultdict(set)
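Note: the newly typed meta_fk helper encodes a two-key lookup, preferring the "val" entry that fake-tensor propagation stores on FX node metadata and falling back to "fake_result". Isolated as a standalone sketch:

from typing import Any

def meta_fk(meta: dict[str, Any]) -> Any:
    # Prefer "val"; fall back to the alternate "fake_result" key.
    return meta["val"] if "val" in meta else meta["fake_result"]

assert meta_fk({"val": 1, "fake_result": 2}) == 1
assert meta_fk({"fake_result": 2}) == 2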
@@ -90,7 +90,9 @@ def find_input_mutations(g):
     return mutated_inputs


-def get_device_node_mapping(gm: torch.fx.GraphModule):
+def get_device_node_mapping(
+    gm: torch.fx.GraphModule,
+) -> dict[torch.device, torch.fx.Node]:
     device_node_mapping: dict[torch.device, torch.fx.Node] = {}
     for n in gm.graph.nodes:
         t = n.meta.get("val", None)
@@ -100,7 +102,7 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):


 def check_for_mutation_ignore_cuda_graph_managed_tensor(
-    aot_model: torch.fx.GraphModule, num_fixed
+    aot_model: torch.fx.GraphModule, num_fixed: int
 ) -> Optional[str]:
     mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
     if not mutation_indices:
@@ -110,7 +112,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor(
     return get_mutation_stack_trace(placeholders, mutation_indices)


-def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
+def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
     if not config.cudagraph_backend_support_input_mutation:
         if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
             aot_model, num_fixed
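Note: check_for_skip is gated on a dynamo config flag (via the `config` import above). A sketch of toggling it; the behavior described in the comment follows the check itself, not separate documentation:

import torch._dynamo.config as dynamo_config

# When this flag is False, a graph that mutates a non-fixed input is not
# cudagraph-ed; the skip reason (with stack trace) is logged instead.
dynamo_config.cudagraph_backend_support_input_mutation = True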
@@ -128,28 +130,35 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
     return None


-def get_device_index(gm) -> int:
+def get_device_index(gm: torch.fx.GraphModule) -> int:
     device = next(iter(get_device_node_mapping(gm)))
     assert device.type == "cuda"
     return device.index


-def get_stack_traces(gm) -> list[Optional[str]]:
+def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
     output = output_node(gm)
     assert len(output.args) == 1
+    args = output.args[0]
+    if not hasattr(args, "__iter__"):
+        return []
     return [
         (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-        for arg in output.args[0]
+        for arg in args  # type: ignore[union-attr]
     ]


-def cudagraphs(dynamo_model, dynamo_inputs):
+def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
     from torch._inductor.cudagraph_trees import cudagraphify_impl

     do_cudagraphs = BoxedBool(True)
     boxed_device_index = BoxedDeviceIndex(None)

-    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
+    def forward_cudagraphs(
+        aot_model: torch.fx.GraphModule,
+        aot_inputs: list[Any],
+        is_inference: bool = False,
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
         if skip_msg := check_for_skip(aot_model, fixed):
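Note: the `# type: ignore[attr-defined]` annotations added below are for the `_boxed_call` marker, which opts a callable into AOTAutograd's "boxed" calling convention: it receives one list of inputs (so the caller can drop its tensor references early) rather than *args. A minimal stand-in for the convention:

from typing import Any

def boxed_fn(inputs: list[Any]) -> list[Any]:
    # Take ownership of the inputs, then clear the caller's list so the
    # tensors it held can be freed as early as possible.
    args = list(inputs)
    inputs.clear()
    return [a * 2 for a in args]

boxed_fn._boxed_call = True  # type: ignore[attr-defined]
assert boxed_fn([1, 2, 3]) == [2, 4, 6]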
@@ -166,15 +175,17 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             range(fixed),
             device_index=boxed_device_index.value,
             is_backward=False,
-            is_inference=False,
+            is_inference=False,  # Q: should we forward is_inference here?
             stack_traces=get_stack_traces(aot_model),
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out

-    def backward_cudagraphs(aot_model, aot_inputs):
+    def backward_cudagraphs(
+        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         if not do_cudagraphs:
             return aot_model
@@ -182,20 +193,23 @@ def cudagraphs(dynamo_model, dynamo_inputs):
         fixed = count_tangents(aot_model)
         if skip_msg := check_for_skip(aot_model, fixed):
             log_cudagraph_skip_and_bump_counter(
-                "skipping cudagraphs due to %s", skip_msg
+                f"skipping cudagraphs due to {skip_msg}"
             )

             # See [Backward Generation Handling]
+            device_idx = boxed_device_index.value
+            if device_idx is None:
+                device_idx = 0  # Default to device 0 if not set
             manager = torch._inductor.cudagraph_trees.get_manager(
-                boxed_device_index.value, create_if_none_exists=False
+                device_idx, create_if_none_exists=False
             )
             assert manager is not None

-            def fn(inputs):
+            def fn(inputs: list[Any]) -> Any:
                 manager.set_to_running_backward()
                 return aot_model(inputs)

-            fn._boxed_call = True
+            fn._boxed_call = True  # type: ignore[attr-defined]
             return fn

         out = cudagraphify_impl(
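Note: BoxedBool and BoxedDeviceIndex are small mutable cells that the forward and backward compile functions close over, which is how the backward pass reads the device index recorded during forward compilation; the new None-check defaults to device 0 when forward never set it. A hypothetical stand-in for the pattern:

from typing import Optional

class BoxedIndex:
    # Hypothetical stand-in for BoxedDeviceIndex: one shared mutable slot.
    def __init__(self, value: Optional[int] = None) -> None:
        self.value = value

    def set(self, value: int) -> None:
        self.value = value

boxed = BoxedIndex(None)
device_idx = boxed.value if boxed.value is not None else 0  # same default as above
assert device_idx == 0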
@@ -209,7 +223,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out

     aot_cudagraphs = aot_autograd(
@@ -225,13 +239,13 @@ class CudagraphsBackend:
     compiler_name = "cudagraphs"

     @staticmethod
-    def reset():
+    def reset() -> None:
         from torch._inductor.cudagraph_trees import reset_cudagraph_trees

         reset_cudagraph_trees()

     @staticmethod
-    def __call__(model, inputs):
+    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
         return cudagraphs(model, inputs)
@@ -240,7 +254,12 @@ class CudagraphsBackend:
 register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


-def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
+def cudagraphs_inner(
+    model: Callable[..., Any],
+    inputs: Sequence[Any],
+    copy_outputs: bool = True,
+    copy_inputs: bool = True,
+) -> Callable[..., Sequence[Any]]:
     """This isn't registered as a backend, but is used in some benchmarks"""
     assert isinstance(inputs, (list, tuple))
     if copy_inputs:
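Note: per its docstring, cudagraphs_inner is a benchmarking helper rather than a registered backend: it captures the model into a CUDA graph once and returns a callable that replays it. A hypothetical usage sketch (CUDA device assumed):

import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner

model = torch.nn.Linear(8, 8).cuda()
inputs = [torch.randn(4, 8, device="cuda")]

run = cudagraphs_inner(model, inputs)  # warms up and captures the graph
(out,) = run(*inputs)                  # replays; inputs/outputs copied per flags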
@@ -265,7 +284,7 @@ def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)

-    def run(*new_inputs):
+    def run(*new_inputs: Any) -> Sequence[Any]:
         assert len(static_inputs) == len(new_inputs)
         if copy_inputs:
             for dst, src in zip(static_inputs, new_inputs):
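Note: the run closure is the replay half of the standard CUDA graph capture/replay pattern. In raw torch.cuda API terms, a minimal sketch (omitting the warm-up iteration the helper performs before capture):

import torch

static_in = torch.randn(4, 8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2  # recorded into the graph, not executed yet

static_in.copy_(torch.randn(4, 8, device="cuda"))  # the copy_inputs step
g.replay()                                          # static_out now updated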

torch/_inductor/cudagraph_utils.py

@@ -14,7 +14,7 @@ from .utils import is_using_cudagraph_partition
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Sequence, Set as AbstractSet


 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -110,7 +110,8 @@ def format_default_skip_message(reason: str) -> str:


 def get_mutation_stack_trace(
-    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
+    placeholders: Sequence[PlaceholderInfo],
+    mutation_indices: Union[AbstractSet[int], Sequence[int]],
 ) -> str:
     stack_trace: Optional[str] = ""
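Note: the widened mutation_indices parameter matches the first file, where find_input_mutations now returns set[int] while other callers pass sequences; a set is not a Sequence, hence the AbstractSet alternative. A minimal sketch of the widened contract:

from collections.abc import Sequence, Set as AbstractSet
from typing import Union

def format_indices(indices: Union[AbstractSet[int], Sequence[int]]) -> str:
    # Both set[int] and list[int] satisfy this parameter type.
    return ", ".join(str(i) for i in sorted(indices))

assert format_indices({2, 0, 1}) == "0, 1, 2"
assert format_indices([0, 1, 2]) == "0, 1, 2"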