[CUDAGraph] Graph Partition (#147648)

This PR implements cudagraph partition, building on the earlier inductor graph partition PR (#147038). Because cudagraph cannot support many ops, this PR focuses on `cpu ops`; more partition rules will be added in a follow-up PR.

## Example
```python
import torch

torch._inductor.config.graph_partition = True

def f(x, y):
    x1 = x + 1
    y1 = y + 1
    # cpu op: cudagraph cannot capture it, so the graph is partitioned around it
    y_cpu = y1.cpu() + 1
    z = x @ y
    return x1 + y1 + z + y_cpu.cuda()

x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
eager_out = f(x, y)

# mode="reduce-overhead" enables cudagraphs
f_compiled = torch.compile(f, mode="reduce-overhead")

# run several iterations and compare against the eager result each time
for _ in range(5):
    compiled_out = f_compiled(x_cloned, y_cloned)
    assert torch.allclose(eager_out, compiled_out)
```

w/o graph partition, we will skip cudagraph:
```
skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from :
   File "/home/boyuan/playground/cudagraph/graph_partition/graph_partition.py", line 9, in f
    y_cpu = y1.cpu() + 1 # 3
```

w/ graph partition, we can see two cudagraphified partitions under the same torch-compiled region:
![image](https://github.com/user-attachments/assets/4e22d428-2687-433d-b92a-0814a2201b25)

## Design

PR #147038 splits the `def call(args)` function into multiple `def partition_id(args)` functions. In this PR, we use `recursively_apply_fns()` to wrap each `partition_id()` function with `cudagraphify`. One major design point is that `cudagraphify` takes metadata such as `static_input_idxs`, so we need to provide this metadata for each graph partition. Previously, however, we only had this metadata for the original graph, not for the individual graph partitions.
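The wrapping step can be pictured with a small, torch-free sketch. Everything here is a simplified assumption for illustration: `RunnerSketch`, `fake_cudagraphify`, and the toy partition bodies are hypothetical, and the real generated wrapper code may differ. It only shows the idea that one wrapper function is supplied per partition, in order, and that `call` then dispatches through the wrapped partitions.

```python
from functools import partial


class RunnerSketch:
    """Hypothetical stand-in for the generated wrapper that owns the partitions."""

    def __init__(self, partitions):
        self.partitions = list(partitions)

    def recursively_apply_fns(self, fns):
        # One wrapper per partition function, in the same order.
        assert len(fns) == len(self.partitions)
        self.partitions = [fn(p) for fn, p in zip(fns, self.partitions)]

    def call(self, args):
        (x,) = args
        (buf0,) = self.partitions[0]([x])
        buf1 = buf0 + 1  # stands in for an op inlined into `call` (e.g. a cpu op)
        (buf2,) = self.partitions[1]([buf1])
        return [buf2]


def partition_0(args):
    return [args[0] * 2]


def partition_1(args):
    return [args[0] * 3]


calls = []


def fake_cudagraphify(fn, *, name):
    # Stand-in for cudagraphify: records which partition ran, then runs it.
    def wrapped(args):
        calls.append(name)
        return fn(args)

    return wrapped


runner = RunnerSketch([partition_0, partition_1])
runner.recursively_apply_fns(
    [partial(fake_cudagraphify, name="p0"), partial(fake_cudagraphify, name="p1")]
)
out = runner.call([5])
```

Using `partial` to bind per-partition keyword arguments mirrors how the PR binds per-partition metadata before handing the list of wrappers to `recursively_apply_fns`.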

The [idea](https://github.com/pytorch/pytorch/pull/147038#discussion_r1964124800) is:
- compute a mapping from the partition metadata (e.g., input/output idx) to the graph metadata, stored in `GraphPartitionMap`.
- during post_compile, get the `CudagraphMetadata` for each partition based on the graph-level metadata and `GraphPartitionMap`, via `get_partition_cudagraph_metadata()`.
- finally, in `cudagraph_partition_post_compile`, we compute the `CudagraphMetadata` and apply cudagraphify to each graph partition via `recursively_apply_fns`.
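The index translation in the steps above can be sketched without torch. The class and function names below (`PartitionMapSketch`, `MetadataSketch`, `to_partition_metadata`) are hypothetical simplifications of `GraphPartitionMap`, `CudagraphMetadata`, and `get_partition_cudagraph_metadata()`; placeholders and constants are left out to keep the example short.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PartitionMapSketch:
    # Hypothetical, simplified stand-in for GraphPartitionMap.
    id: int
    # None means the partition input/output is not a graph input/output.
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]


@dataclass(frozen=True)
class MetadataSketch:
    # Hypothetical, simplified stand-in for CudagraphMetadata.
    static_input_idxs: frozenset[int]
    mutated_input_idxs: frozenset[int]
    stack_traces: list[Optional[str]]


def to_partition_metadata(pmap, graph_md):
    static, mutated = set(), set()
    for partition_idx, graph_idx in enumerate(pmap.input_index_mapping):
        # A None graph index is never in either set, so it is skipped naturally.
        if graph_idx in graph_md.static_input_idxs:
            static.add(partition_idx)
        if graph_idx in graph_md.mutated_input_idxs:
            mutated.add(partition_idx)
    traces = [
        None if graph_idx is None else graph_md.stack_traces[graph_idx]
        for graph_idx in pmap.output_index_mapping
    ]
    return MetadataSketch(frozenset(static), frozenset(mutated), traces)


# Graph level: input 0 is static, input 2 is mutated, and there are two outputs.
graph_md = MetadataSketch(frozenset({0}), frozenset({2}), ["trace_out0", "trace_out1"])
# Partition inputs map to graph inputs 2, <none>, 0; outputs map to <none>, 1.
pmap = PartitionMapSketch(0, [2, None, 0], [None, 1])
partition_md = to_partition_metadata(pmap, graph_md)
```

In this example, graph input 0 appears as partition input 2, so the static index moves from 0 to 2, while the intermediate buffer at partition input 1 is neither static nor mutated.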

#### Q: How does it work with codecache?

While we have multiple graph partitions, we still have 1 file and 1 `call` function for 1 dynamo graph. The major difference is that we additionally need to load `recursively_apply_fns()` for graph partition. We also add `partition_maps: Optional[list[GraphPartitionMap]]` to `CompiledFxGraph` so that it is serialized and can be deserialized later.
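As a rough illustration of this round trip, the sketch below keeps the partition maps as plain data and re-attaches the callables from a loaded module afterwards. It is an assumption-laden toy: `PartitionMapSketch`, `save`, and `load` are hypothetical names, `pickle` on plain dicts stands in for the real cache serialization, and `SimpleNamespace` stands in for the loaded code module. The `getattr(..., None)` fallback matches how the diff loads `recursively_apply_fns`.

```python
import pickle
from dataclasses import asdict, dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass(frozen=True)
class PartitionMapSketch:
    # Hypothetical stand-in for GraphPartitionMap: plain data only.
    id: int
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]
    constant_names: list[str]


def save(partition_maps):
    # Callables are not serialized; only the data describing the partitions is.
    maps = None if partition_maps is None else [asdict(m) for m in partition_maps]
    return pickle.dumps({"partition_maps": maps})


def load(blob, loaded_module):
    maps = pickle.loads(blob)["partition_maps"]
    if maps is not None:
        maps = [PartitionMapSketch(**m) for m in maps]
    call = loaded_module.call
    # Modules generated without graph partition have no recursively_apply_fns.
    apply_fns = getattr(loaded_module, "recursively_apply_fns", None)
    return maps, call, apply_fns


maps = [PartitionMapSketch(0, [0, None], [None, 1], ["c0"])]
blob = save(maps)

with_partitions = SimpleNamespace(
    call=lambda args: args, recursively_apply_fns=lambda fns: None
)
without_partitions = SimpleNamespace(call=lambda args: args)

restored, call, apply_fns = load(blob, with_partitions)
_, _, missing = load(blob, without_partitions)
```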

## Edge Case 1
PyTorch makes assumptions about input/output order. For example, backward inputs take saved tensors first and then tangents. In graph partition, we respect this order via `graph_partition_signature_reorder`.

## Edge Case 2
Cudagraphifying `call` function gives 2 cudagraph managed tensors `buf0` and `primals_1`. However, cudagraphifying `partition_0` gives only 1 cudagraph managed tensor `buf0`. This leads to a semantic difference between cudagraph w/ and w/o graph partition. [full code comparison](https://www.internalfb.com/intern/diffing/?paste_number=1747654420)

![image](https://github.com/user-attachments/assets/03d08ce0-f1d1-4d1d-8432-805a07e1dd40)

To achieve the same semantics, we return an input tensor as an output if it is not freed in a graph partition. This allows more cudagraph managed tensors and is important for handling saved tensors.
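A minimal sketch of this rule, using names only: `partition_outputs` is a hypothetical helper, and the buffer names are borrowed from the updated `FileCheck` expectation in the test diff. The real scheduler additionally checks that each name maps to a scheduler node, which is omitted here.

```python
def partition_outputs(produced_needed_later, partition_inputs, freed_in_partition):
    # Inputs that the partition does not free are also returned as outputs.
    extra = [name for name in partition_inputs if name not in freed_in_partition]
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys([*produced_needed_later, *extra]))


before = ["buf0", "buf1"]
# Neither input is freed inside the partition, so both are returned as well.
after = partition_outputs(before, ["arg0_1", "arg1_1"], freed_in_partition=set())
# An input freed inside the partition is not returned.
freed = partition_outputs(before, ["arg0_1", "arg1_1"], freed_in_partition={"arg1_1"})
```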

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147648
Approved by: https://github.com/eellison
Boyuan Feng 2025-03-13 16:00:21 +00:00 committed by PyTorch MergeBot
parent 65d19a5699
commit 3e605fe46d
9 changed files with 545 additions and 56 deletions


@ -2527,6 +2527,187 @@ if HAS_CUDA:
eager_result = f(example_input)
self.assertEqual(compiled_result, eager_result)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition(self):
def f(x, y):
x1 = x + 1
y1 = y + 1
y_cpu = y1.cpu() + 1
z = x @ y
return x1 + y1 + z + y_cpu.cuda()
x, y = [torch.randn(2, 2, device="cuda") for _ in range(2)]
x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
eager_out = f(x, y)
f_compiled = torch.compile(f, mode="reduce-overhead")
for _ in range(5):
compiled_out = f_compiled(x_cloned, y_cloned)
self.assertEqual(eager_out, compiled_out)
# 2 graph partitions lead to 2 cudagraph
self.assertEqual(self.get_manager().new_graph_id().id, 2)
@torch._inductor.config.patch("graph_partition", True)
@torch._inductor.config.patch("triton.cudagraphs", False)
def test_graph_partition_reduce_overhead_mode_effectiveness(self):
# test that `mode="reduce-overhead"` still controls whether
# cudagraph is applied. i.e., cudagraph is not applied when
# mode="default".
def f(x, y):
x1 = x + 1
y1 = y + 1
y_cpu = y1.cpu() + 1
z = x @ y
return x1 + y1 + z + y_cpu.cuda()
x, y = [torch.randn(2, 2, device="cuda") for _ in range(2)]
f_compiled = torch.compile(f)
for _ in range(5):
_out = f_compiled(x, y)
self.assertEqual(self.get_manager() is None, True)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_forward_backward(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
x1 = x + 1
y1 = x + 2
y_cpu = y1.cpu() + 1
z = x @ y1
inp = x1 + y1 + z + y_cpu.cuda()
return self.linear(inp)
model = Mod().cuda()
input_data = torch.randn(16, 16).cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model, mode="reduce-overhead")
for _ in range(5):
output = compiled_model(input_data)
loss = criterion(output, torch.randint(0, 10, (16,)).cuda())
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 2 graph partitions lead to 2 fwd cudagraphs and 2 bwd cudagraphs
self.assertEqual(self.get_manager().new_graph_id().id, 4)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_cpu_only(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
x1 = x + 1
y1 = x + 2
y_cpu = y1 + 1
z = x @ y1
inp = x1 + y1 + z + y_cpu
return self.linear(inp)
model = Mod().cpu()
input_data = torch.randn(16, 16).cpu()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model, mode="default")
for _ in range(5):
output = compiled_model(input_data)
loss = criterion(output, torch.randint(0, 10, (16,)).cpu())
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 0 cudagraph since all ops are on cpu
self.assertEqual(self.get_manager() is None, True)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_forward_with_skipped_cudagraphed_backward(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x * x * x
for _ in range(3):
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
out.backward(back_inp)
# we should not have cudagraph'd the backwards
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 1)
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_forward_backward_not_called(self):
# tests saved tensor is handled correctly
def foo(x, y):
x_out = x * x * x
torch._dynamo.graph_break()
y_out = y * y * y
return x_out, y_out
foo = torch.compile(foo, mode="reduce-overhead")
for _ in range(3):
inps = [
torch.rand([20, 20], requires_grad=True, device="cuda")
for _ in range(2)
]
x_out, y_out = foo(inps[0], inps[1])
x_out.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
# we should not have cudagraph'd the y backward
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 3)
@requires_multigpu()
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_multiple_devices_msg(self):
def foo(x, y):
return (x + 1, y + 2)
foo = torch.compile(foo, mode="reduce-overhead")
for _ in range(3):
foo(torch.ones([10], device="cuda"), torch.ones([20]))
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
with capture_stderr() as captured_output:
for _ in range(3):
foo(
torch.ones([10], device="cuda:0"),
torch.ones([10], device="cuda:1"),
)
FileCheck().check("skipping cudagraphs due to multiple devices").run(
captured_output[0]
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 1)
class TestSAC(TestCase):
def _make_observer_mode(self):
class ObserverMode(TorchDispatchMode):


@ -14128,7 +14128,7 @@ if RUN_GPU:
if not config.cpp_wrapper:
FileCheck().check("def partition_0(args):").check(
- "(buf0, buf1) = self.partitions[0](partition0_args)"
+ "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)"
).check("recursively_apply_fns = runner.recursively_apply_fns").run(
code[0]
)
@ -14237,6 +14237,26 @@ if RUN_GPU:
self.assertEqual(eager_out, compiled_out)
@torch._inductor.config.patch("graph_partition", True)
def test_graph_partition_fused_scheduler_node(self):
def foo(x):
x = x * 20
x_alias = x[0]
y = x * 10
y_alias = y[0]
torch._dynamo.graph_break()
ind = torch.tensor(4, device=GPU_TYPE)
x_alias2 = x[ind:]
y_alias2 = y[ind:]
return x, x_alias, x_alias2, y_alias, y_alias2
foo = torch.compile(foo)
x = torch.rand([20, 20], device=GPU_TYPE)
_, code = run_and_get_code(foo, x)
if not config.cpp_wrapper:
FileCheck().check("def partition_0(args):").run(code[0])
class RNNTest(TestCase):
device_type = GPU_TYPE


@ -1151,7 +1151,7 @@ class _InProcessFxCompile(FxCompile):
# not going to touch it for now
compiled_fn: Any
recursively_apply_fns = None
with dynamo_timed(
"GraphLowering.compile_to_fn", log_pt2_compile_event=True
):
@ -1203,7 +1203,11 @@ class _InProcessFxCompile(FxCompile):
],
)
else:
compiled_fn = graph.compile_to_module().call
compiled_module = graph.compile_to_module()
compiled_fn = compiled_module.call
recursively_apply_fns = getattr(
compiled_module, "recursively_apply_fns", None
)
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
metrics.num_bytes_accessed += num_bytes
@ -1277,6 +1281,7 @@ class _InProcessFxCompile(FxCompile):
graph_kwargs,
inputs_to_check,
boxed_forward_device_index,
recursively_apply_fns,
)


@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters, get_metrics_context
- from torch._inductor.utils import InputType
+ from torch._inductor.utils import GraphPartitionMap, InputType
from torch.utils._ordered_set import OrderedSet
@ -167,6 +167,10 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
if torch._inductor.config.graph_partition:
# graph partition supports splitting on cpu op. So we can ignore cpu nodes.
device_node_mapping.pop(torch.device("cpu"), None)
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
if stack_trace := _get_use_stack_trace(cpu_node):
@ -338,3 +342,70 @@ class CudagraphCachedInfo:
placeholders: Sequence[PlaceholderInfo]
stack_traces: list[Optional[str]]
cudagraph_fail_reasons: list[str]
@dataclasses.dataclass(frozen=True)
class CudagraphMetadata:
    """
    Metadata for recording a CUDA graph.
    """

    placeholders: Sequence[PlaceholderInfo]
    static_input_idxs: OrderedSet[int]
    mutated_input_idxs: OrderedSet[int]
    stack_traces: list[Optional[str]]
    constants: dict[str, torch.Tensor]


def get_partition_cudagraph_metadata(
    partition_map: GraphPartitionMap,
    metadata: CudagraphMetadata,
) -> CudagraphMetadata:
    """
    Convert the cudagraph metadata at the graph level to the graph partition level,
    given the graph partition info (i.e., mapping from partition input/output index
    to graph input/output index).
    """
    partition_placeholders = []
    partition_static_input_idxs: OrderedSet[int] = OrderedSet()
    partition_mutated_input_idxs: OrderedSet[int] = OrderedSet()
    for partition_input_idx, graph_input_idx in enumerate(
        partition_map.input_index_mapping
    ):
        if graph_input_idx in metadata.static_input_idxs:
            partition_static_input_idxs.add(partition_input_idx)

        if graph_input_idx in metadata.mutated_input_idxs:
            partition_mutated_input_idxs.add(partition_input_idx)

        if graph_input_idx is not None:
            placeholder = metadata.placeholders[graph_input_idx]
        else:
            # create a dummy placeholder info since this partition input is not a graph input
            placeholder = PlaceholderInfo(
                name=f"partition_{partition_map.id}_placeholder_{partition_input_idx}",
                stack_trace=None,
                users=[],
                mutating_use_stack_trace=None,
            )
        partition_placeholders.append(placeholder)

    partition_stack_traces = []
    for graph_output_idx in partition_map.output_index_mapping:
        if graph_output_idx is not None:
            partition_stack_traces.append(metadata.stack_traces[graph_output_idx])
        else:
            partition_stack_traces.append(None)

    partition_constants = {
        name: metadata.constants[name] for name in partition_map.constant_names
    }

    return CudagraphMetadata(
        partition_placeholders,
        partition_static_input_idxs,
        partition_mutated_input_idxs,
        partition_stack_traces,
        partition_constants,
    )


@ -97,6 +97,7 @@ from .utils import (
get_cloned_parameter_buffer_name,
get_donated_idxs,
get_sympy_Expr_dtype,
GraphPartitionMap,
is_same_tensor,
maybe_get_suppress_shape_guards_ctx,
normalize_name,
@ -320,6 +321,7 @@ class GraphLowering(torch.fx.Interpreter):
self.graph_input_names: list[str] = []
self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
self.graph_inputs_original: dict[str, InputBuffer] = {}
self.partition_maps: Optional[list[GraphPartitionMap]] = None
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet()


@ -188,11 +188,15 @@ class GraphPartitionSignature:
# we cannot get name from Expr.
input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]]
output_nodes: list[IRNode]
# mapping from partition input name to a boolean for whether deallocating it
# in the partition function
input_deallocation: dict[str, bool]
skip_cudagraph: bool
# name of constants read/written by the graph partition
constant_names: list[str]
def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None:


@ -26,6 +26,7 @@ import dataclasses
import logging
import os
import re
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
@ -35,6 +36,8 @@ from torch._dynamo.utils import counters, get_runtime_metrics_context
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
CudagraphCachedInfo,
CudagraphMetadata,
get_partition_cudagraph_metadata,
get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
)
@ -42,6 +45,7 @@ from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.utils import (
align_inputs_from_check_idxs,
BoxedBool,
GraphPartitionMap,
InputType,
output_node,
set_tracing_context_output_strides,
@ -132,6 +136,48 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
return False
def maybe_handle_backward_generation(compiled_graph: CompiledFxGraph) -> None:
assert compiled_graph.current_callable is not None
is_backward = compiled_graph.fx_kwargs["is_backward"]
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
# See [Backward Generation Handling]
# if cudagraph'd the forward and set the device, we need to let the cudagraph manager
# know we are we running the backward even if we will not run it in cudagraphs
if is_backward and config.triton.cudagraph_trees:
assert boxed_forward_device_index is not None
assert boxed_forward_device_index.value is not None
compiled_graph_callable = compiled_graph.current_callable
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_forward_device_index.value, create_if_none_exists=False
)
# should already exist from forward
assert manager is not None
def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)
compiled_graph.current_callable = compiled_artifact
def prepare_cudagraph_post_compile(
compiled_graph: CompiledFxGraph, example_inputs: Sequence[InputType]
) -> None:
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
for t in example_inputs:
if isinstance(t, torch.SymInt):
int(t) # guard
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
if boxed_forward_device_index is not None and not is_inference and not is_backward:
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
@ -147,7 +193,6 @@ def cudagraph_post_compile(
assert compiled_graph.cudagraph_info is not None
cached_info = compiled_graph.cudagraph_info
cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
@ -157,18 +202,8 @@ def cudagraph_post_compile(
placeholders = cached_info.placeholders
stack_traces = cached_info.stack_traces
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
for t in example_inputs:
if isinstance(t, torch.SymInt):
int(t) # guard
if (
boxed_forward_device_index is not None
and not is_inference
and not is_backward
):
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
prepare_cudagraph_post_compile(compiled_graph, example_inputs)
from .compile_fx import cudagraphify
@ -188,26 +223,7 @@ def cudagraph_post_compile(
else:
BoxedBool.disable(cudagraphs)
# See [Backward Generation Handling]
# if cudagraph'd the forward and set the device, we need to let the cudagraph manager
# know we are we running the backward even if we will not run it in cudagraphs
if is_backward and config.triton.cudagraph_trees:
assert boxed_forward_device_index is not None
assert boxed_forward_device_index.value is not None
compiled_graph_callable = compiled_graph.current_callable
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_forward_device_index.value, create_if_none_exists=False
)
# should already exist from forward
assert manager is not None
def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)
compiled_graph.current_callable = compiled_artifact
maybe_handle_backward_generation(compiled_graph)
if "cuda" in compiled_graph.device_types:
# prefer better disable_cudagraphs_reason bc stack trace
@ -222,6 +238,78 @@ def cudagraph_post_compile(
)
def cudagraph_partition_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
constants: dict[str, torch.Tensor],
) -> None:
"""
Cudagraphify each partition functions, which first prepares the necessary
metadata and then applies the cudagraphify function to each partition.
Assuming all partition functions are cudagraphified and share the same order
as `compiled_graph.partition_maps`. See [Note: Graph Partition Map for CUDAGraph].
"""
assert compiled_graph.cudagraph_info is not None
cudagraph_fail_reasons = compiled_graph.cudagraph_info.cudagraph_fail_reasons
if (
cudagraph_fail_reasons
or compiled_graph.partition_maps is None
or len(compiled_graph.partition_maps) == 0
):
# cudagraphify is not called if there are no partitions
BoxedBool.disable(cudagraphs)
maybe_handle_backward_generation(compiled_graph)
return
from .compile_fx import cudagraphify
assert compiled_graph.current_callable is not None
assert compiled_graph.recursively_apply_fns is not None
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ())
mutated_input_idxs = compiled_graph.mutated_input_idxs
device_index = next(iter(compiled_graph.device_idxs))
graph_metadata = CudagraphMetadata(
compiled_graph.cudagraph_info.placeholders,
static_input_idxs,
mutated_input_idxs,
compiled_graph.cudagraph_info.stack_traces,
constants,
)
prepare_cudagraph_post_compile(compiled_graph, example_inputs)
# cudagraphify each partition function, assuming every graph partition function
# is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into
# `call` function and not included in partition functions.
cudagraphify_fns = []
for partition_map in compiled_graph.partition_maps:
partition_metadata = get_partition_cudagraph_metadata(
partition_map,
graph_metadata,
)
cudagraphify_fn = partial(
cudagraphify,
static_input_idxs=tuple(partition_metadata.static_input_idxs),
device_index=device_index,
stack_traces=partition_metadata.stack_traces,
is_backward=is_backward,
is_inference=is_inference,
constants=tuple(partition_metadata.constants.values()),
placeholders=partition_metadata.placeholders,
mutated_input_idxs=tuple(partition_metadata.mutated_input_idxs),
)
cudagraphify_fns.append(cudagraphify_fn)
compiled_graph.recursively_apply_fns(cudagraphify_fns)
def maybe_realign_inputs(
ran_cudagraphs: BoxedBool,
compiled_graph: CompiledFxGraph,
@ -294,6 +382,7 @@ class CompiledFxGraph(OutputCode):
"""
current_callable: Optional[Callable[..., Any]]
recursively_apply_fns: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[list[tuple[int, str]]]
@ -316,6 +405,7 @@ class CompiledFxGraph(OutputCode):
guards_expr: Optional[str]
cudagraph_info: Optional[CudagraphCachedInfo]
partition_maps: Optional[list[GraphPartitionMap]]
fx_kwargs: _CompileFxKwargs
inputs_to_check: Sequence[int]
boxed_forward_device_index: Optional[BoxedDeviceIndex]
@ -338,8 +428,10 @@ class CompiledFxGraph(OutputCode):
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
boxed_forward_device_index: Optional[BoxedDeviceIndex],
recursively_apply_fns: Optional[Callable[..., Any]] = None,
) -> None:
self.current_callable = current_callable
self.recursively_apply_fns = recursively_apply_fns
self.cache_key = graph.cache_key
if graph.cache_path:
with open(graph.cache_path) as f:
@ -377,6 +469,7 @@ class CompiledFxGraph(OutputCode):
self.counter_deltas = counter_deltas
self.guards_expr = None
self.cudagraph_info = None
self.partition_maps = graph.partition_maps
self.fx_kwargs = {}
self.inputs_to_check = ()
self.boxed_forward_device_index = None
@ -492,12 +585,23 @@ class CompiledFxGraph(OutputCode):
counters["inductor"]["cudagraph_skips"] += 1
BoxedBool.disable(cudagraphs)
else:
cudagraph_post_compile(
example_inputs,
self,
cudagraphs,
constants.unwrap(self),
)
if config.graph_partition:
# with graph_partition=True, we skip some cudagraph checks if it's supported
# with partition. So we have to use cudagraph_partition_post_compile.
cudagraph_partition_post_compile(
example_inputs,
self,
cudagraphs,
constants.unwrap(self),
)
else:
cudagraph_post_compile(
example_inputs,
self,
cudagraphs,
constants.unwrap(self),
)
inputs_to_check = self.inputs_to_check
# cudagraphs could have been disabled from the earlier conditions
# so we still need to realign inputs if that happens
@ -516,6 +620,7 @@ class CompiledFxGraph(OutputCode):
# TODO: This could be better if we're ever able to serialize compiled
# models to disk.
self.current_callable = None
self.recursively_apply_fns = None
def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
from torch._dynamo.utils import counters, dynamo_timed
@ -551,12 +656,16 @@ class CompiledFxGraph(OutputCode):
"PyCodeCache.load_by_key_path",
log_pt2_compile_event=True,
):
self.current_callable = PyCodeCache.load_by_key_path(
code_cache = PyCodeCache.load_by_key_path(
self.cache_key,
artifact_path,
self.cache_linemap,
constants.unwrap(self),
).call
)
self.current_callable = code_cache.call
self.recursively_apply_fns = getattr(
code_cache, "recursively_apply_fns", None
)
except OSError:
log.error("Failed to load artifact: %s", artifact_path)
raise


@ -57,6 +57,7 @@ from .utils import (
get_device_tflops,
get_dtype_size,
get_gpu_dram_gbps,
GraphPartitionMap,
IndentedBuffer,
is_collective,
is_gpu,
@ -3964,6 +3965,9 @@ class Scheduler:
def should_partition(self, node: BaseSchedulerNode) -> bool:
"""Return True if we should partition the inductor graph on this node"""
if isinstance(node, FusedSchedulerNode):
return any(self.should_partition(snode) for snode in node.snodes)
if not node.is_gpu():
return True
@ -3979,9 +3983,13 @@ class Scheduler:
if getattr(node.node, "unbacked_bindings", None):
return True
if hasattr(node.node, "layout") and any(
isinstance(expr, sympy.Expr) and expr.free_symbols
for expr in node.node.layout.size
if (
hasattr(node.node, "layout")
and hasattr(node.node.layout, "size")
and any(
isinstance(expr, sympy.Expr) and expr.free_symbols
for expr in node.node.layout.size
)
):
return True
@ -4003,6 +4011,47 @@ class Scheduler:
return name_to_node
def compute_graph_partition_maps(
self,
signatures: list[GraphPartitionSignature],
) -> None:
"""
computes a mapping from partition input/output indices to graph input/output
indices for each partition.
"""
name_to_graph_input_index = {
name: idx for idx, name in enumerate(V.graph.graph_inputs)
}
name_to_graph_output_index = {
name: idx for idx, name in enumerate(V.graph.get_output_names())
}
V.graph.partition_maps = []
for partition_id, signature in enumerate(signatures):
if signature.skip_cudagraph:
# Note: [Graph Partition Map for CUDAGraph]
# number of partition map should be the same as the number of generated
# partition functions. This assumption will be used when cudagraphify
# each partition function.
continue
input_mapping = []
for name in signature.input_nodes:
input_mapping.append(name_to_graph_input_index.get(name))
output_mapping = []
for node in signature.output_nodes:
output_mapping.append(name_to_graph_output_index.get(node.get_name()))
V.graph.partition_maps.append(
GraphPartitionMap(
partition_id,
input_mapping,
output_mapping,
signature.constant_names,
)
)
def get_graph_partition_signature(
self, partitions: list[PartitionType], skip_cudagraphs: list[bool]
) -> list[GraphPartitionSignature]:
@ -4026,7 +4075,7 @@ class Scheduler:
returned_output_names = output_names.intersection(unmet_output_names)
# all reads/writes are partition inputs except those generated
- # within the partition
+ # within the partition and tensor constants
read_writes = dependencies.ReadWrites.merge_list(
[node.read_writes for node in partition]
)
@ -4049,15 +4098,35 @@ class Scheduler:
for name in partition_input_names
if name in name_to_node
}
# if an input tensor is not freed in the partition function, it should
# also be returned as an output. This brings benefits to cudagraph
# since the returned output tensor is a cudagraph managed tensor with
# a static tensor address.
extra_output_names = [
name
for name in partition_input_names
if name in name_to_node and name not in buffer_names_to_free
]
returned_output_names.update(extra_output_names)
output_nodes = [name_to_node[name] for name in returned_output_names]
signatures.append(
GraphPartitionSignature(
input_nodes,
output_nodes,
input_deallocation,
skip_cudagraph,
)
constant_names = [
name for name in partition_input_names if name in V.graph.constants
]
partition_signature = GraphPartitionSignature(
input_nodes,
output_nodes,
input_deallocation,
skip_cudagraph,
constant_names,
)
signatures.append(partition_signature)
unmet_output_names = partition_input_names.union(
unmet_output_names - returned_output_names
)
@ -4090,9 +4159,12 @@ class Scheduler:
partitions.append(cur_partition)
skip_cudagraphs.append(skip_cudagraph)
return partitions, self.get_graph_partition_signature(
signatures = self.get_graph_partition_signature(
partitions=partitions, skip_cudagraphs=skip_cudagraphs
)
self.compute_graph_partition_maps(signatures)
return partitions, signatures
def codegen(self) -> None:
with dynamo_timed("Scheduler.codegen"):
@ -4149,6 +4221,13 @@ class Scheduler:
num_partitions = next(self._graph_partition_counter)
V.graph.wrapper_code.set_all_partition_names(num_partitions)
# See [Note: Graph Partition Map for CUDAGraph]
if num_partitions > 0:
assert V.graph.partition_maps is not None
assert num_partitions == len(V.graph.partition_maps), (
f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}"
)
def _codegen(self, nodes: list[BaseSchedulerNode]) -> None:
if config.check_stack_no_cycles_TESTING_ONLY:
import torch._dynamo.convert_frame


@ -151,6 +151,24 @@ class align(sympy.Function):
return value
@dataclasses.dataclass(frozen=True)
class GraphPartitionMap:
    """
    Mapping from the partition info (e.g., input/output) to the graph info
    """

    # a unique id of graph partition
    id: int

    # map partition input/output indices to graph input/output indices. None indicates
    # a partition input/output is not a graph input/output.
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]

    # name of constants read/written by the graph partition
    constant_names: list[str]
def do_bench_using_profiling(
fn: Callable[[], Any], warmup: int = 25, rep: int = 100
) -> float: