[CUDAGraph] Graph Partition (#147648)
This PR implements cudagraph partition, following the previous PR on inductor graph partition (#147038). Since there are many ops that cudagraph cannot support, this PR focuses on `cpu ops`; more partition rules will be added in the next PR.

## Example

```python
import torch

torch._inductor.config.graph_partition = True

def f(x, y):
    x1 = x + 1
    y1 = y + 1
    y_cpu = y1.cpu() + 1
    z = x @ y
    return x1 + y1 + z + y_cpu.cuda()

x, y = [torch.ones(2, 2, device="cuda") for _ in range(2)]
x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
eager_out = f(x, y)

f_compiled = torch.compile(f, mode="reduce-overhead")

for _ in range(5):
    compiled_out = f_compiled(x_cloned, y_cloned)
    assert torch.allclose(eager_out, compiled_out)
```

Without graph partition, we skip cudagraph:

```
skipping cudagraphs due to skipping cudagraphs due to cpu device (device_put). Found from :
  File "/home/boyuan/playground/cudagraph/graph_partition/graph_partition.py", line 9, in f
    y_cpu = y1.cpu() + 1 # 3
```

With graph partition, we can see two cudagraphify calls under the same torch-compiled region.

## Design

PR #147038 splits the `def call(args)` function into multiple `def partition_id(args)` functions. In this PR, we use `recursively_apply_fns()` to wrap each `partition_id()` function with `cudagraphify`.

One major design point: `cudagraphify` takes metadata such as `static_input_idxs`, and we need to provide such metadata for each graph partition. However, we previously only had this metadata for the original graph, not for the graph partitions. The [idea](https://github.com/pytorch/pytorch/pull/147038#discussion_r1964124800) is:

- compute a mapping from the partition metadata (e.g., input/output idx) to the graph metadata, stored in `GraphPartitionMap`.
- during post_compile, get the `CudagraphMetadata` for each partition based on the graph-level metadata and `GraphPartitionMap`, via `get_partition_cudagraph_metadata()`.
- finally, in `cudagraph_partition_post_compile`, we compute the `CudagraphMetadata` and apply cudagraphify to each graph partition via `recursively_apply_fns`.

(A simplified sketch of this remapping is shown right after this description.)

#### Q: How does it work with codecache?

While we have multiple graph partitions, we still have one file and one `call` function per dynamo graph. The major difference is that we additionally load a `recursively_apply_fns()` for graph partition. We also add `partition_maps: Optional[list[GraphPartitionMap]]` to `CompiledFxGraph` so it is serialized and can be deserialized later.

## Edge Case 1

PyTorch makes assumptions about input/output order. For example, backward inputs take saved tensors first and then tangents. In graph partition, we respect such orders via `graph_partition_signature_reorder`.

## Edge Case 2

Cudagraphifying the `call` function gives 2 cudagraph-managed tensors, `buf0` and `primals_1`. However, cudagraphifying `partition_0` gives only 1 cudagraph-managed tensor, `buf0`. This leads to a semantic difference between cudagraph with and without graph partition ([full code comparison](https://www.internalfb.com/intern/diffing/?paste_number=1747654420)).

To achieve the same semantics, we return an input tensor as an output if it is not freed in a graph partition. This allows more cudagraph-managed tensors and is important for handling saved tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147648
Approved by: https://github.com/eellison
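To make the Design section above concrete, here is a minimal, self-contained sketch of the index-remapping idea. It is not the actual inductor implementation (that lives in `get_partition_cudagraph_metadata` in the diff below); `PartitionMap` and `Metadata` are simplified stand-in types introduced only for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PartitionMap:
    # simplified stand-in for GraphPartitionMap: maps each partition input/output
    # index to the corresponding graph input/output index (None = not a graph input/output)
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]


@dataclass(frozen=True)
class Metadata:
    # simplified stand-in for the graph-level cudagraph metadata
    static_input_idxs: set[int]
    mutated_input_idxs: set[int]
    stack_traces: list[Optional[str]]


def remap_metadata(pm: PartitionMap, graph_md: Metadata) -> Metadata:
    """Translate graph-level cudagraph metadata into partition-level metadata."""
    static, mutated = set(), set()
    for part_idx, graph_idx in enumerate(pm.input_index_mapping):
        # a partition input inherits static/mutated status from the graph input it maps to
        if graph_idx in graph_md.static_input_idxs:
            static.add(part_idx)
        if graph_idx in graph_md.mutated_input_idxs:
            mutated.add(part_idx)
    # partition outputs that are not graph outputs get no stack trace
    stack_traces = [
        graph_md.stack_traces[graph_idx] if graph_idx is not None else None
        for graph_idx in pm.output_index_mapping
    ]
    return Metadata(static, mutated, stack_traces)


# Example: partition input 0 is graph input 1; partition input 1 is an intermediate buffer.
pm = PartitionMap(input_index_mapping=[1, None], output_index_mapping=[0, None])
graph_md = Metadata(static_input_idxs={1}, mutated_input_idxs=set(), stack_traces=["trace0", "trace1"])
print(remap_metadata(pm, graph_md))  # partition input 0 becomes a static input
```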
This commit is contained in:
parent 65d19a5699
commit 3e605fe46d
@@ -2527,6 +2527,187 @@ if HAS_CUDA:

        eager_result = f(example_input)
        self.assertEqual(compiled_result, eager_result)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition(self):
        def f(x, y):
            x1 = x + 1
            y1 = y + 1
            y_cpu = y1.cpu() + 1
            z = x @ y
            return x1 + y1 + z + y_cpu.cuda()

        x, y = [torch.randn(2, 2, device="cuda") for _ in range(2)]
        x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]]
        eager_out = f(x, y)

        f_compiled = torch.compile(f, mode="reduce-overhead")

        for _ in range(5):
            compiled_out = f_compiled(x_cloned, y_cloned)
            self.assertEqual(eager_out, compiled_out)

        # 2 graph partitions lead to 2 cudagraph
        self.assertEqual(self.get_manager().new_graph_id().id, 2)

    @torch._inductor.config.patch("graph_partition", True)
    @torch._inductor.config.patch("triton.cudagraphs", False)
    def test_graph_partition_reduce_overhead_mode_effectiveness(self):
        # test that `mode="reduce-overhead"` still controls whether
        # cudagraph is applied. i.e., cudagraph is not applied when
        # mode="default".
        def f(x, y):
            x1 = x + 1
            y1 = y + 1
            y_cpu = y1.cpu() + 1
            z = x @ y
            return x1 + y1 + z + y_cpu.cuda()

        x, y = [torch.randn(2, 2, device="cuda") for _ in range(2)]

        f_compiled = torch.compile(f)
        for _ in range(5):
            _out = f_compiled(x, y)
        self.assertEqual(self.get_manager() is None, True)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_forward_backward(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)

            def forward(self, x):
                x1 = x + 1
                y1 = x + 2
                y_cpu = y1.cpu() + 1
                z = x @ y1
                inp = x1 + y1 + z + y_cpu.cuda()
                return self.linear(inp)

        model = Mod().cuda()

        input_data = torch.randn(16, 16).cuda()

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        compiled_model = torch.compile(model, mode="reduce-overhead")

        for _ in range(5):
            output = compiled_model(input_data)
            loss = criterion(output, torch.randint(0, 10, (16,)).cuda())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 2 graph partitions lead to 2 fwd cudagraphs and 2 bwd cudagraphs
        self.assertEqual(self.get_manager().new_graph_id().id, 4)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_cpu_only(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)

            def forward(self, x):
                x1 = x + 1
                y1 = x + 2
                y_cpu = y1 + 1
                z = x @ y1
                inp = x1 + y1 + z + y_cpu
                return self.linear(inp)

        model = Mod().cpu()

        input_data = torch.randn(16, 16).cpu()

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        compiled_model = torch.compile(model, mode="default")

        for _ in range(5):
            output = compiled_model(input_data)
            loss = criterion(output, torch.randint(0, 10, (16,)).cpu())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 0 cudagraph since all ops are on cpu
        self.assertEqual(self.get_manager() is None, True)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_forward_with_skipped_cudagraphed_backward(self):
        @torch.compile(mode="reduce-overhead")
        def foo(x):
            return x * x * x

        for _ in range(3):
            inp = torch.rand([20, 20], device="cuda", requires_grad=True)
            out = foo(inp)

            with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
                back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
                out.backward(back_inp)

        # we should not have cudagraph'd the backwards
        new_id = self.get_manager().new_graph_id().id
        self.assertEqual(new_id, 1)

        self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_forward_backward_not_called(self):
        # tests saved tensor is handled correctly
        def foo(x, y):
            x_out = x * x * x
            torch._dynamo.graph_break()
            y_out = y * y * y
            return x_out, y_out

        foo = torch.compile(foo, mode="reduce-overhead")

        for _ in range(3):
            inps = [
                torch.rand([20, 20], requires_grad=True, device="cuda")
                for _ in range(2)
            ]
            x_out, y_out = foo(inps[0], inps[1])
            x_out.sum().backward()

        self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)

        # we should not have cudagraph'd the y backward
        new_id = self.get_manager().new_graph_id().id
        self.assertEqual(new_id, 3)

    @requires_multigpu()
    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_multiple_devices_msg(self):
        def foo(x, y):
            return (x + 1, y + 2)

        foo = torch.compile(foo, mode="reduce-overhead")
        for _ in range(3):
            foo(torch.ones([10], device="cuda"), torch.ones([20]))

        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

        with capture_stderr() as captured_output:
            for _ in range(3):
                foo(
                    torch.ones([10], device="cuda:0"),
                    torch.ones([10], device="cuda:1"),
                )

        FileCheck().check("skipping cudagraphs due to multiple devices").run(
            captured_output[0]
        )
        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
        new_id = self.get_manager().new_graph_id().id
        self.assertEqual(new_id, 1)


class TestSAC(TestCase):
    def _make_observer_mode(self):
        class ObserverMode(TorchDispatchMode):
@@ -14128,7 +14128,7 @@ if RUN_GPU:

        if not config.cpp_wrapper:
            FileCheck().check("def partition_0(args):").check(
-               "(buf0, buf1) = self.partitions[0](partition0_args)"
+               "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)"
            ).check("recursively_apply_fns = runner.recursively_apply_fns").run(
                code[0]
            )
@@ -14237,6 +14237,26 @@

        self.assertEqual(eager_out, compiled_out)

    @torch._inductor.config.patch("graph_partition", True)
    def test_graph_partition_fused_scheduler_node(self):
        def foo(x):
            x = x * 20
            x_alias = x[0]
            y = x * 10
            y_alias = y[0]
            torch._dynamo.graph_break()
            ind = torch.tensor(4, device=GPU_TYPE)
            x_alias2 = x[ind:]
            y_alias2 = y[ind:]
            return x, x_alias, x_alias2, y_alias, y_alias2

        foo = torch.compile(foo)
        x = torch.rand([20, 20], device=GPU_TYPE)
        _, code = run_and_get_code(foo, x)

        if not config.cpp_wrapper:
            FileCheck().check("def partition_0(args):").run(code[0])


class RNNTest(TestCase):
    device_type = GPU_TYPE
@@ -1151,7 +1151,7 @@ class _InProcessFxCompile(FxCompile):

            # not going to touch it for now

            compiled_fn: Any

            recursively_apply_fns = None
            with dynamo_timed(
                "GraphLowering.compile_to_fn", log_pt2_compile_event=True
            ):
@@ -1203,7 +1203,11 @@

                        ],
                    )
                else:
-                   compiled_fn = graph.compile_to_module().call
+                   compiled_module = graph.compile_to_module()
+                   compiled_fn = compiled_module.call
+                   recursively_apply_fns = getattr(
+                       compiled_module, "recursively_apply_fns", None
+                   )

            num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
            metrics.num_bytes_accessed += num_bytes
@@ -1277,6 +1281,7 @@

                graph_kwargs,
                inputs_to_check,
                boxed_forward_device_index,
                recursively_apply_fns,
            )
@@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters, get_metrics_context
-from torch._inductor.utils import InputType
+from torch._inductor.utils import GraphPartitionMap, InputType
from torch.utils._ordered_set import OrderedSet
@@ -167,6 +167,10 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:

def check_multiple_devices_or_any_cpu_nodes(
    device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
    if torch._inductor.config.graph_partition:
        # graph partition supports splitting on cpu op. So we can ignore cpu nodes.
        device_node_mapping.pop(torch.device("cpu"), None)

    if cpu_node := device_node_mapping.get(torch.device("cpu")):
        msg = f"cpu device ({cpu_node.name})"
        if stack_trace := _get_use_stack_trace(cpu_node):
@@ -338,3 +342,70 @@ class CudagraphCachedInfo:

    placeholders: Sequence[PlaceholderInfo]
    stack_traces: list[Optional[str]]
    cudagraph_fail_reasons: list[str]


@dataclasses.dataclass(frozen=True)
class CudagraphMetadata:
    """
    Metadata for recording a CUDA graph.
    """

    placeholders: Sequence[PlaceholderInfo]
    static_input_idxs: OrderedSet[int]
    mutated_input_idxs: OrderedSet[int]
    stack_traces: list[Optional[str]]
    constants: dict[str, torch.Tensor]


def get_partition_cudagraph_metadata(
    partition_map: GraphPartitionMap,
    metadata: CudagraphMetadata,
) -> CudagraphMetadata:
    """
    Convert the cudagraph metadata at the graph level to the graph partition level,
    given the graph partition info (i.e., mapping from partition input/output index
    to graph input/output index).
    """

    partition_placeholders = []
    partition_static_input_idxs: OrderedSet[int] = OrderedSet()
    partition_mutated_input_idxs: OrderedSet[int] = OrderedSet()
    for partition_input_idx, graph_input_idx in enumerate(
        partition_map.input_index_mapping
    ):
        if graph_input_idx in metadata.static_input_idxs:
            partition_static_input_idxs.add(partition_input_idx)

        if graph_input_idx in metadata.mutated_input_idxs:
            partition_mutated_input_idxs.add(partition_input_idx)

        if graph_input_idx is not None:
            placeholder = metadata.placeholders[graph_input_idx]
        else:
            # create a dummy placeholder info since this partition input is not a graph input
            placeholder = PlaceholderInfo(
                name=f"partition_{partition_map.id}_placeholder_{partition_input_idx}",
                stack_trace=None,
                users=[],
                mutating_use_stack_trace=None,
            )
        partition_placeholders.append(placeholder)

    partition_stack_traces = []
    for graph_output_idx in partition_map.output_index_mapping:
        if graph_output_idx is not None:
            partition_stack_traces.append(metadata.stack_traces[graph_output_idx])
        else:
            partition_stack_traces.append(None)

    partition_constants = {
        name: metadata.constants[name] for name in partition_map.constant_names
    }

    return CudagraphMetadata(
        partition_placeholders,
        partition_static_input_idxs,
        partition_mutated_input_idxs,
        partition_stack_traces,
        partition_constants,
    )
@@ -97,6 +97,7 @@ from .utils import (

    get_cloned_parameter_buffer_name,
    get_donated_idxs,
    get_sympy_Expr_dtype,
    GraphPartitionMap,
    is_same_tensor,
    maybe_get_suppress_shape_guards_ctx,
    normalize_name,
@@ -320,6 +321,7 @@ class GraphLowering(torch.fx.Interpreter):

        self.graph_input_names: list[str] = []
        self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
        self.graph_inputs_original: dict[str, InputBuffer] = {}
        self.partition_maps: Optional[list[GraphPartitionMap]] = None
        self.zero_dim_cpu_tensor_list = OrderedSet[str]()
        self.device_types: OrderedSet[str] = (
            const_module.device_types if const_module else OrderedSet()
@@ -188,11 +188,15 @@ class GraphPartitionSignature:

    # we cannot get name from Expr.
    input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]]
    output_nodes: list[IRNode]

    # mapping from partition input name to a boolean for whether deallocating it
    # in the partition function
    input_deallocation: dict[str, bool]
    skip_cudagraph: bool

    # name of constants read/written by the graph partition
    constant_names: list[str]


def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
    def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None:
@@ -26,6 +26,7 @@ import dataclasses

import logging
import os
import re
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
@@ -35,6 +36,8 @@ from torch._dynamo.utils import counters, get_runtime_metrics_context

from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    CudagraphCachedInfo,
    CudagraphMetadata,
    get_partition_cudagraph_metadata,
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
@@ -42,6 +45,7 @@ from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param

from torch._inductor.utils import (
    align_inputs_from_check_idxs,
    BoxedBool,
    GraphPartitionMap,
    InputType,
    output_node,
    set_tracing_context_output_strides,
@@ -132,6 +136,48 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:

    return False


def maybe_handle_backward_generation(compiled_graph: CompiledFxGraph) -> None:
    assert compiled_graph.current_callable is not None
    is_backward = compiled_graph.fx_kwargs["is_backward"]
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index

    # See [Backward Generation Handling]
    # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
    # know we are we running the backward even if we will not run it in cudagraphs
    if is_backward and config.triton.cudagraph_trees:
        assert boxed_forward_device_index is not None
        assert boxed_forward_device_index.value is not None
        compiled_graph_callable = compiled_graph.current_callable

        manager = torch._inductor.cudagraph_trees.get_manager(
            boxed_forward_device_index.value, create_if_none_exists=False
        )
        # should already exist from forward
        assert manager is not None

        def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
            manager.set_to_running_backward()  # type: ignore[union-attr]
            return compiled_graph_callable(new_inputs)

        compiled_graph.current_callable = compiled_artifact


def prepare_cudagraph_post_compile(
    compiled_graph: CompiledFxGraph, example_inputs: Sequence[InputType]
) -> None:
    if not config.triton.cudagraph_trees:
        # Force specialize all inputs so that CUDA graphs will work
        for t in example_inputs:
            if isinstance(t, torch.SymInt):
                int(t)  # guard

    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]
    if boxed_forward_device_index is not None and not is_inference and not is_backward:
        boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))


def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
@@ -147,7 +193,6 @@

    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]
@@ -157,18 +202,8 @@

    placeholders = cached_info.placeholders
    stack_traces = cached_info.stack_traces
-   if not config.triton.cudagraph_trees:
-       # Force specialize all inputs so that CUDA graphs will work
-       for t in example_inputs:
-           if isinstance(t, torch.SymInt):
-               int(t)  # guard
-
-   if (
-       boxed_forward_device_index is not None
-       and not is_inference
-       and not is_backward
-   ):
-       boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
+   prepare_cudagraph_post_compile(compiled_graph, example_inputs)

    from .compile_fx import cudagraphify
@@ -188,26 +223,7 @@

    else:
        BoxedBool.disable(cudagraphs)

-   # See [Backward Generation Handling]
-   # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
-   # know we are we running the backward even if we will not run it in cudagraphs
-   if is_backward and config.triton.cudagraph_trees:
-       assert boxed_forward_device_index is not None
-       assert boxed_forward_device_index.value is not None
-       compiled_graph_callable = compiled_graph.current_callable
-
-       manager = torch._inductor.cudagraph_trees.get_manager(
-           boxed_forward_device_index.value, create_if_none_exists=False
-       )
-       # should already exist from forward
-       assert manager is not None
-
-       def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
-           manager.set_to_running_backward()  # type: ignore[union-attr]
-           return compiled_graph_callable(new_inputs)
-
-       compiled_graph.current_callable = compiled_artifact
+   maybe_handle_backward_generation(compiled_graph)

    if "cuda" in compiled_graph.device_types:
        # prefer better disable_cudagraphs_reason bc stack trace
@@ -222,6 +238,78 @@

    )


def cudagraph_partition_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    constants: dict[str, torch.Tensor],
) -> None:
    """
    Cudagraphify each partition functions, which first prepares the necessary
    metadata and then applies the cudagraphify function to each partition.

    Assuming all partition functions are cudagraphified and share the same order
    as `compiled_graph.partition_maps`. See [Note: Graph Partition Map for CUDAGraph].
    """
    assert compiled_graph.cudagraph_info is not None
    cudagraph_fail_reasons = compiled_graph.cudagraph_info.cudagraph_fail_reasons

    if (
        cudagraph_fail_reasons
        or compiled_graph.partition_maps is None
        or len(compiled_graph.partition_maps) == 0
    ):
        # cudagraphify is not called if there are no partitions
        BoxedBool.disable(cudagraphs)
        maybe_handle_backward_generation(compiled_graph)
        return

    from .compile_fx import cudagraphify

    assert compiled_graph.current_callable is not None
    assert compiled_graph.recursively_apply_fns is not None
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]
    static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ())
    mutated_input_idxs = compiled_graph.mutated_input_idxs
    device_index = next(iter(compiled_graph.device_idxs))

    graph_metadata = CudagraphMetadata(
        compiled_graph.cudagraph_info.placeholders,
        static_input_idxs,
        mutated_input_idxs,
        compiled_graph.cudagraph_info.stack_traces,
        constants,
    )

    prepare_cudagraph_post_compile(compiled_graph, example_inputs)

    # cudagraphify each partition function, assuming every graph partition function
    # is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into
    # `call` function and not included in partition functions.
    cudagraphify_fns = []
    for partition_map in compiled_graph.partition_maps:
        partition_metadata = get_partition_cudagraph_metadata(
            partition_map,
            graph_metadata,
        )

        cudagraphify_fn = partial(
            cudagraphify,
            static_input_idxs=tuple(partition_metadata.static_input_idxs),
            device_index=device_index,
            stack_traces=partition_metadata.stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(partition_metadata.constants.values()),
            placeholders=partition_metadata.placeholders,
            mutated_input_idxs=tuple(partition_metadata.mutated_input_idxs),
        )
        cudagraphify_fns.append(cudagraphify_fn)

    compiled_graph.recursively_apply_fns(cudagraphify_fns)


def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
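For orientation, the call `compiled_graph.recursively_apply_fns(cudagraphify_fns)` above relies on the generated module exposing one wrapper slot per partition function, in the same order as `partition_maps`. A rough, hypothetical sketch of that contract (the `PartitionRunner` class and `fake_cudagraphify` helper below are illustrative stand-ins, not the actual inductor-generated code):

```python
from typing import Any, Callable


class PartitionRunner:
    """Hypothetical stand-in for the generated module's partition holder."""

    def __init__(self, partitions: list[Callable[..., Any]]) -> None:
        self.partitions = partitions

    def recursively_apply_fns(
        self, fns: list[Callable[[Callable[..., Any]], Callable[..., Any]]]
    ) -> None:
        # wrap each partition function in order; assumes len(fns) == len(self.partitions),
        # mirroring [Note: Graph Partition Map for CUDAGraph]
        assert len(fns) == len(self.partitions)
        self.partitions = [fn(part) for fn, part in zip(fns, self.partitions)]


def fake_cudagraphify(fn: Callable[..., Any]) -> Callable[..., Any]:
    # a real wrapper would record and replay a CUDA graph here; this one just forwards
    def wrapped(*args: Any) -> Any:
        return fn(*args)

    return wrapped


# usage sketch: wrap every partition with the (fake) cudagraphify-style decorator
runner = PartitionRunner([lambda x: x + 1, lambda x: x * 2])
runner.recursively_apply_fns([fake_cudagraphify, fake_cudagraphify])
print(runner.partitions[0](3), runner.partitions[1](3))  # 4 6
```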
@@ -294,6 +382,7 @@ class CompiledFxGraph(OutputCode):

    """

    current_callable: Optional[Callable[..., Any]]
    recursively_apply_fns: Optional[Callable[..., Any]]
    cache_key: str
    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
    cache_linemap: Optional[list[tuple[int, str]]]
@@ -316,6 +405,7 @@

    guards_expr: Optional[str]

    cudagraph_info: Optional[CudagraphCachedInfo]
    partition_maps: Optional[list[GraphPartitionMap]]
    fx_kwargs: _CompileFxKwargs
    inputs_to_check: Sequence[int]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]
@@ -338,8 +428,10 @@

        fx_kwargs: _CompileFxKwargs,
        inputs_to_check: Sequence[int],
        boxed_forward_device_index: Optional[BoxedDeviceIndex],
        recursively_apply_fns: Optional[Callable[..., Any]] = None,
    ) -> None:
        self.current_callable = current_callable
        self.recursively_apply_fns = recursively_apply_fns
        self.cache_key = graph.cache_key
        if graph.cache_path:
            with open(graph.cache_path) as f:
@@ -377,6 +469,7 @@

        self.counter_deltas = counter_deltas
        self.guards_expr = None
        self.cudagraph_info = None
        self.partition_maps = graph.partition_maps
        self.fx_kwargs = {}
        self.inputs_to_check = ()
        self.boxed_forward_device_index = None
@@ -492,12 +585,23 @@

            counters["inductor"]["cudagraph_skips"] += 1
            BoxedBool.disable(cudagraphs)
        else:
-           cudagraph_post_compile(
-               example_inputs,
-               self,
-               cudagraphs,
-               constants.unwrap(self),
-           )
+           if config.graph_partition:
+               # with graph_partition=True, we skip some cudagraph checks if it's supported
+               # with partition. So we have to use cudagraph_partition_post_compile.
+               cudagraph_partition_post_compile(
+                   example_inputs,
+                   self,
+                   cudagraphs,
+                   constants.unwrap(self),
+               )
+           else:
+               cudagraph_post_compile(
+                   example_inputs,
+                   self,
+                   cudagraphs,
+                   constants.unwrap(self),
+               )

        inputs_to_check = self.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
@@ -516,6 +620,7 @@

        # TODO: This could be better if we're ever able to serialize compiled
        # models to disk.
        self.current_callable = None
        self.recursively_apply_fns = None

    def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
        from torch._dynamo.utils import counters, dynamo_timed
@@ -551,12 +656,16 @@

                "PyCodeCache.load_by_key_path",
                log_pt2_compile_event=True,
            ):
-               self.current_callable = PyCodeCache.load_by_key_path(
+               code_cache = PyCodeCache.load_by_key_path(
                    self.cache_key,
                    artifact_path,
                    self.cache_linemap,
                    constants.unwrap(self),
-               ).call
+               )
+               self.current_callable = code_cache.call
+               self.recursively_apply_fns = getattr(
+                   code_cache, "recursively_apply_fns", None
+               )
        except OSError:
            log.error("Failed to load artifact: %s", artifact_path)
            raise
@@ -57,6 +57,7 @@ from .utils import (

    get_device_tflops,
    get_dtype_size,
    get_gpu_dram_gbps,
    GraphPartitionMap,
    IndentedBuffer,
    is_collective,
    is_gpu,
@@ -3964,6 +3965,9 @@

    def should_partition(self, node: BaseSchedulerNode) -> bool:
        """Return True if we should partition the inductor graph on this node"""
        if isinstance(node, FusedSchedulerNode):
            return any(self.should_partition(snode) for snode in node.snodes)

        if not node.is_gpu():
            return True
@@ -3979,9 +3983,13 @@

        if getattr(node.node, "unbacked_bindings", None):
            return True

-       if hasattr(node.node, "layout") and any(
-           isinstance(expr, sympy.Expr) and expr.free_symbols
-           for expr in node.node.layout.size
+       if (
+           hasattr(node.node, "layout")
+           and hasattr(node.node.layout, "size")
+           and any(
+               isinstance(expr, sympy.Expr) and expr.free_symbols
+               for expr in node.node.layout.size
+           )
        ):
            return True
@@ -4003,6 +4011,47 @@

        return name_to_node

    def compute_graph_partition_maps(
        self,
        signatures: list[GraphPartitionSignature],
    ) -> None:
        """
        computes a mapping from partition input/output indices to graph input/output
        indices for each partition.
        """
        name_to_graph_input_index = {
            name: idx for idx, name in enumerate(V.graph.graph_inputs)
        }
        name_to_graph_output_index = {
            name: idx for idx, name in enumerate(V.graph.get_output_names())
        }

        V.graph.partition_maps = []
        for partition_id, signature in enumerate(signatures):
            if signature.skip_cudagraph:
                # Note: [Graph Partition Map for CUDAGraph]
                # number of partition map should be the same as the number of generated
                # partition functions. This assumption will be used when cudagraphify
                # each partition function.
                continue

            input_mapping = []
            for name in signature.input_nodes:
                input_mapping.append(name_to_graph_input_index.get(name))

            output_mapping = []
            for node in signature.output_nodes:
                output_mapping.append(name_to_graph_output_index.get(node.get_name()))

            V.graph.partition_maps.append(
                GraphPartitionMap(
                    partition_id,
                    input_mapping,
                    output_mapping,
                    signature.constant_names,
                )
            )

    def get_graph_partition_signature(
        self, partitions: list[PartitionType], skip_cudagraphs: list[bool]
    ) -> list[GraphPartitionSignature]:
@@ -4026,7 +4075,7 @@

            returned_output_names = output_names.intersection(unmet_output_names)

            # all reads/writes are partition inputs except those generated
-           # within the partition
+           # within the partition and tensor constants
            read_writes = dependencies.ReadWrites.merge_list(
                [node.read_writes for node in partition]
            )
@@ -4049,15 +4098,35 @@

                for name in partition_input_names
                if name in name_to_node
            }

            # if an input tensor is not freed in the partition function, it should
            # also be returned as an output. This brings benefits to cudagraph
            # since the returned output tensor is a cudagraph managed tensor with
            # a static tensor address.
            extra_output_names = [
                name
                for name in partition_input_names
                if name in name_to_node and name not in buffer_names_to_free
            ]

            returned_output_names.update(extra_output_names)

            output_nodes = [name_to_node[name] for name in returned_output_names]
-           signatures.append(
-               GraphPartitionSignature(
-                   input_nodes,
-                   output_nodes,
-                   input_deallocation,
-                   skip_cudagraph,
-               )

+           constant_names = [
+               name for name in partition_input_names if name in V.graph.constants
+           ]
+
+           partition_signature = GraphPartitionSignature(
+               input_nodes,
+               output_nodes,
+               input_deallocation,
+               skip_cudagraph,
+               constant_names,
+           )
+
+           signatures.append(partition_signature)

            unmet_output_names = partition_input_names.union(
                unmet_output_names - returned_output_names
            )
@@ -4090,9 +4159,12 @@

            partitions.append(cur_partition)
            skip_cudagraphs.append(skip_cudagraph)

-       return partitions, self.get_graph_partition_signature(
+       signatures = self.get_graph_partition_signature(
            partitions=partitions, skip_cudagraphs=skip_cudagraphs
        )
+       self.compute_graph_partition_maps(signatures)
+
+       return partitions, signatures

    def codegen(self) -> None:
        with dynamo_timed("Scheduler.codegen"):
@@ -4149,6 +4221,13 @@

        num_partitions = next(self._graph_partition_counter)
        V.graph.wrapper_code.set_all_partition_names(num_partitions)

        # See [Note: Graph Partition Map for CUDAGraph]
        if num_partitions > 0:
            assert V.graph.partition_maps is not None
            assert num_partitions == len(V.graph.partition_maps), (
                f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}"
            )

    def _codegen(self, nodes: list[BaseSchedulerNode]) -> None:
        if config.check_stack_no_cycles_TESTING_ONLY:
            import torch._dynamo.convert_frame
@@ -151,6 +151,24 @@ class align(sympy.Function):

        return value


@dataclasses.dataclass(frozen=True)
class GraphPartitionMap:
    """
    Mapping from the partition info (e.g., input/output) to the graph info
    """

    # a unique id of graph partition
    id: int

    # map partition input/output indices to graph input/output indices. None indicates
    # a partition input/output is not a graph input/output.
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]

    # name of constants read/written by the graph partition
    constant_names: list[str]


def do_bench_using_profiling(
    fn: Callable[[], Any], warmup: int = 25, rep: int = 100
) -> float:
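As a reading aid, here is a hypothetical instance of the `GraphPartitionMap` dataclass defined above. The concrete index values are purely illustrative (not what inductor necessarily produces for any particular graph), and the import only exists once this PR's change to `torch/_inductor/utils.py` has landed.

```python
# Assumes a build of PyTorch that already contains this PR.
from torch._inductor.utils import GraphPartitionMap

# Hypothetical partition: its first input is graph input 0, its second input is an
# intermediate buffer produced by an earlier partition (hence None), its only output
# is graph output 2, and it reads no constants.
partition_map = GraphPartitionMap(
    id=0,
    input_index_mapping=[0, None],
    output_index_mapping=[2],
    constant_names=[],
)
print(partition_map)
```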