Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Type cudagraphs.py (#160363)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160363
Approved by: https://github.com/StrongerXi
ghstack dependencies: #160362
This commit is contained in:
parent f82c7eed84
commit 6fe6dd9fdc
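
For context, the module being annotated here implements and registers the "cudagraphs" TorchDynamo backend (see the register_backend call in the diff below). A minimal usage sketch, assuming a CUDA-capable build; the toy model and inputs are hypothetical:

import torch

model = torch.nn.Linear(8, 8).cuda()
compiled = torch.compile(model, backend="cudagraphs")  # selects the backend registered below
out = compiled(torch.randn(4, 8, device="cuda"))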

torch/_dynamo/backends/cudagraphs.py
@@ -1,5 +1,3 @@
-# mypy: ignore-errors
-
 """
 This module implements CUDA graphs support for TorchDynamo backends.
 
@@ -25,9 +23,11 @@ Key components:
 import functools
 from collections import defaultdict
-from typing import Optional
+from collections.abc import Sequence
+from typing import Any, Callable, Optional
 
 import torch
 import torch.fx
 from torch._dynamo import config
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.backends.debugging import boxed_nop
@@ -51,8 +51,8 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from .registry import register_backend
 
 
-def find_input_mutations(g):
-    def meta_fk(meta):
+def find_input_mutations(g: torch.fx.Graph) -> set[int]:
+    def meta_fk(meta: dict[str, Any]) -> Any:
         return meta["val"] if "val" in meta else meta["fake_result"]
 
     inputs = defaultdict(set)
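
The meta_fk helper above reads the fake-tensor metadata that tracing attaches to each FX node, stored under "val" (newer) or "fake_result" (older). One way to see that metadata, and to exercise find_input_mutations, is a graph produced by make_fx; a sketch, where the toy function f is ours:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    x.add_(1)  # in-place mutation of the input
    return x * 2

gm = make_fx(f, tracing_mode="fake")(torch.randn(3))
for n in gm.graph.nodes:
    print(n.op, n.target, n.meta.get("val"))  # each node carries a "val" fake tensor
# find_input_mutations(gm.graph) should then report {0}: placeholder 0 is mutated.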
@@ -90,7 +90,9 @@ def find_input_mutations(g):
     return mutated_inputs
 
 
-def get_device_node_mapping(gm: torch.fx.GraphModule):
+def get_device_node_mapping(
+    gm: torch.fx.GraphModule,
+) -> dict[torch.device, torch.fx.Node]:
     device_node_mapping: dict[torch.device, torch.fx.Node] = {}
     for n in gm.graph.nodes:
         t = n.meta.get("val", None)
@@ -100,7 +102,7 @@ def get_device_node_mapping(gm: torch.fx.GraphModule):
 
 
 def check_for_mutation_ignore_cuda_graph_managed_tensor(
-    aot_model: torch.fx.GraphModule, num_fixed
+    aot_model: torch.fx.GraphModule, num_fixed: int
 ) -> Optional[str]:
     mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
     if not mutation_indices:
@ -110,7 +112,7 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor(
|
|||
return get_mutation_stack_trace(placeholders, mutation_indices)
|
||||
|
||||
|
||||
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
|
||||
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
|
||||
if not config.cudagraph_backend_support_input_mutation:
|
||||
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
|
||||
aot_model, num_fixed
|
||||
|
|
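
As the hunk above shows, check_for_skip only rejects a graph for input mutation when config.cudagraph_backend_support_input_mutation is off; with the flag on, the backend handles mutated inputs itself. A sketch of toggling it (the flag name is taken directly from the diff; whether to enable it depends on your workload):

import torch._dynamo.config as dynamo_config

dynamo_config.cudagraph_backend_support_input_mutation = True  # opt in to mutation support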
@@ -128,28 +130,35 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
     return None
 
 
-def get_device_index(gm) -> int:
+def get_device_index(gm: torch.fx.GraphModule) -> int:
     device = next(iter(get_device_node_mapping(gm)))
     assert device.type == "cuda"
     return device.index
 
 
-def get_stack_traces(gm) -> list[Optional[str]]:
+def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
     output = output_node(gm)
     assert len(output.args) == 1
+    args = output.args[0]
+    if not hasattr(args, "__iter__"):
+        return []
     return [
         (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
-        for arg in output.args[0]
+        for arg in args  # type: ignore[union-attr]
     ]
 
 
-def cudagraphs(dynamo_model, dynamo_inputs):
+def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
     from torch._inductor.cudagraph_trees import cudagraphify_impl
 
     do_cudagraphs = BoxedBool(True)
     boxed_device_index = BoxedDeviceIndex(None)
 
-    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
+    def forward_cudagraphs(
+        aot_model: torch.fx.GraphModule,
+        aot_inputs: list[Any],
+        is_inference: bool = False,
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
         if skip_msg := check_for_skip(aot_model, fixed):
@@ -166,15 +175,17 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             range(fixed),
             device_index=boxed_device_index.value,
             is_backward=False,
-            is_inference=False,
+            is_inference=False,  # Q: should forward is_inference here?
             stack_traces=get_stack_traces(aot_model),
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out
 
-    def backward_cudagraphs(aot_model, aot_inputs):
+    def backward_cudagraphs(
+        aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
+    ) -> Any:
         interp = boxed_nop(aot_model, aot_inputs)
         if not do_cudagraphs:
             return aot_model
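
The _boxed_call attribute set on the compiled callables above opts into AOTAutograd's boxed calling convention: the function receives its inputs as one mutable list and may clear it so input tensors can be freed early. A minimal sketch of the convention (the compiled function here is a stand-in, not the real artifact):

def compiled(inputs):  # inputs arrives as one list, not unpacked positional args
    args = list(inputs)
    inputs.clear()  # drop caller references so tensors can be freed sooner
    return [args[0] + 1]

compiled._boxed_call = True  # marks the callable as boxed for AOTAutograd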
@@ -182,20 +193,23 @@ def cudagraphs(dynamo_model, dynamo_inputs):
         fixed = count_tangents(aot_model)
         if skip_msg := check_for_skip(aot_model, fixed):
             log_cudagraph_skip_and_bump_counter(
-                "skipping cudagraphs due to %s", skip_msg
+                f"skipping cudagraphs due to {skip_msg}"
             )
 
             # See [Backward Generation Handling]
+            device_idx = boxed_device_index.value
+            if device_idx is None:
+                device_idx = 0  # Default to device 0 if not set
             manager = torch._inductor.cudagraph_trees.get_manager(
-                boxed_device_index.value, create_if_none_exists=False
+                device_idx, create_if_none_exists=False
             )
             assert manager is not None
 
-            def fn(inputs):
+            def fn(inputs: list[Any]) -> Any:
                 manager.set_to_running_backward()
                 return aot_model(inputs)
 
-            fn._boxed_call = True
+            fn._boxed_call = True  # type: ignore[attr-defined]
             return fn
 
         out = cudagraphify_impl(
@@ -209,7 +223,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
             placeholders=get_placeholder_info(aot_model.graph),
             mutated_input_idxs=find_input_mutations(aot_model.graph),
         )
-        out._boxed_call = True
+        out._boxed_call = True  # type: ignore[attr-defined]
         return out
 
     aot_cudagraphs = aot_autograd(
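
The aot_cudagraphs object that begins at the end of this hunk wires the two closures above into AOT Autograd, which partitions the captured graph and hands the forward and backward halves to the given compilers. A reduced sketch of that wiring, with no-op compilers standing in for forward_cudagraphs and backward_cudagraphs:

from torch._dynamo.backends.common import aot_autograd

noop_backend = aot_autograd(
    fw_compiler=lambda gm, example_inputs: gm,  # stand-in forward compiler
    bw_compiler=lambda gm, example_inputs: gm,  # stand-in backward compiler
)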
@@ -225,13 +239,13 @@ class CudagraphsBackend:
     compiler_name = "cudagraphs"
 
     @staticmethod
-    def reset():
+    def reset() -> None:
         from torch._inductor.cudagraph_trees import reset_cudagraph_trees
 
         reset_cudagraph_trees()
 
     @staticmethod
-    def __call__(model, inputs):
+    def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
         return cudagraphs(model, inputs)
 
 
@@ -240,7 +254,12 @@ class CudagraphsBackend:
 register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
 
 
-def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
+def cudagraphs_inner(
+    model: Callable[..., Any],
+    inputs: Sequence[Any],
+    copy_outputs: bool = True,
+    copy_inputs: bool = True,
+) -> Callable[..., Sequence[Any]]:
     """This isn't registered as a backend, but is used in some benchmarks"""
     assert isinstance(inputs, (list, tuple))
     if copy_inputs:
@@ -265,7 +284,7 @@ def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
     if not isinstance(static_outputs, (list, tuple)):
         static_outputs = (static_outputs,)
 
-    def run(*new_inputs):
+    def run(*new_inputs: Any) -> Sequence[Any]:
         assert len(static_inputs) == len(new_inputs)
         if copy_inputs:
             for dst, src in zip(static_inputs, new_inputs):
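
cudagraphs_inner wraps the capture-and-replay pattern that the public torch.cuda.CUDAGraph API exposes: warm up on a side stream, capture once into static buffers, then replay after copying fresh inputs in, which is what the copy_inputs/copy_outputs flags above control. A standalone sketch of that pattern with a hypothetical toy function (the wrapper returned by cudagraphs_inner(model, inputs) plays the role of launch here):

import torch

def toy(x):
    return torch.sin(x) + 1

static_in = torch.randn(8, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs docs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    toy(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = toy(static_in)  # records the kernels into g

def launch(new_x):
    static_in.copy_(new_x)  # the copy_inputs path: refresh the static buffer
    g.replay()  # re-launch the captured kernels
    return static_out.clone()  # the copy_outputs path: detach from static memory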

torch/_inductor/cudagraph_utils.py (second file in this commit; path inferred from its contents)
@@ -14,7 +14,7 @@ from .utils import is_using_cudagraph_partition
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Sequence, Set as AbstractSet
 
 
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -110,7 +110,8 @@ def format_default_skip_message(reason: str) -> str:
 
 
 def get_mutation_stack_trace(
-    placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int]
+    placeholders: Sequence[PlaceholderInfo],
+    mutation_indices: Union[AbstractSet[int], Sequence[int]],
 ) -> str:
     stack_trace: Optional[str] = ""
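
The widened mutation_indices annotation matches its callers: find_input_mutations returns a set[int], while other call sites pass index sequences, so the parameter must accept both. A small self-contained illustration of the same typing pattern:

from collections.abc import Sequence, Set as AbstractSet
from typing import Union

def takes_indices(idxs: Union[AbstractSet[int], Sequence[int]]) -> list[int]:
    return sorted(idxs)  # both forms are iterable, so this typechecks either way

takes_indices({3, 1})  # an AbstractSet[int]
takes_indices([1, 3])  # a Sequence[int]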