pytorch/torch/_dynamo/backends/cudagraphs.py
Boyuan Feng 9a71d12d92 [CUDAGraphTree] Support mutated inputs from prior cudagraph pool (#123231)
# PR
This PR supports mutating inputs in cudagraph trees when those inputs are outputs from a previous cudagraph. Please check #121861 for more details.

# Note on Optimistic Mutation Check
To determine whether to apply cudagraph, we need to check input mutations, which fall into four categories: a) no mutation, b) mutation on parameters/buffers, c) mutation on cudagraph-recorded tensors, d) mutation on non-cudagraph-recorded tensors. We can apply cudagraph for types a, b, and c, but not for type d. The input mutation type depends on the function, the current node, and the inputs.

Since `check_for_mutation` is slow, there is a trade-off between making type c and type d faster.
- To make type d faster, we want to run `check_for_mutation` early and call the eager function immediately. However, this adds unnecessary overhead to types a, b, and c due to the extra check.
- To make type c faster, we want to skip `check_for_mutation` at the beginning and only run `check_for_mutation` before `record_function` for a new function. This removes the overhead of `check_for_mutation` for types a, b, and c. However, it adds extra overhead to type d, since `check_invariants` runs for all child nodes.

Instead, we design an optimistic mutation check. The assumption is that, given a function and a node, the input mutation type usually remains the same across inputs. So, once we have detected a function on a node as type d, we never treat it as type c again. The detailed design is (see the sketch after this list):
- [Slow Path] On the first invocation of a function on a node, we run `check_for_mutation` once and cache the input mutation type as `non_cudagraph_managed_mutation[node_id][func_id]`.
- [Fast Path] On subsequent invocations of a function on a node, we skip `check_for_mutation`. If `non_cudagraph_managed_mutation[node_id][func_id]` is true, we directly call the eager function. Otherwise, we run `check_invariants` and call the cudagraph function.
- [Slow Path] Before `record_function`, we run `check_for_mutation` again.
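
A minimal sketch of this caching scheme; `run_eager` and `run_cudagraph` are hypothetical stand-ins for the real CUDAGraph-tree entry points, while `check_for_mutation`, `check_invariants`, and `non_cudagraph_managed_mutation` are the names used above:

```python
from collections import defaultdict

# node_id -> func_id -> True when the first invocation mutated tensors that are
# not managed by a cudagraph (type d).
non_cudagraph_managed_mutation = defaultdict(dict)


def run(func_id, node_id, inputs):
    cache = non_cudagraph_managed_mutation[node_id]
    if func_id not in cache:
        # Slow path: first invocation of this function on this node.
        cache[func_id] = check_for_mutation(func_id, node_id, inputs)
    if cache[func_id]:
        # Type d: fall back to the eager (or inductor-optimized) function.
        return run_eager(func_id, inputs)
    # Types a/b/c: verify cached assumptions, then replay or record the cudagraph.
    check_invariants(node_id, inputs)
    return run_cudagraph(func_id, node_id, inputs)
```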

**Q1: Would there be overhead for types a, b, c, and d?**
A: No. We only check input mutation types on the first invocation of a function on a node.

**Q2: If a function happens to be type c during the first invocation on a node, could we detect it as type d in the future?**
A: Yes. This is done by `check_invariants`, which guarantees correctness.

**Q3: If a function happens to be type d during the first invocation on a node, could it still be recognized as type c in the future?**
A: No. But according to our assumption, this should rarely happen. In the rare case that it does, there are no correctness issues and the performance is the same as the eager (or inductor-optimized) function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123231
Approved by: https://github.com/eellison
2024-04-19 10:32:12 +00:00


# mypy: ignore-errors

import functools
import operator
from collections import defaultdict
from typing import Dict, List, Optional

import torch

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    get_placeholders,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    has_incompatible_cudagraph_ops,
    num_fw_fixed_arguments,
    output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend

perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")


def find_input_mutations(g):
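    # Map each placeholder's storage to its input index, then flag an input as
    # mutated whenever a call_function writes (alias_info.is_write) to an
    # argument sharing that storage.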
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]

    # TODO: error on unrecognized nodes
    return mutated_inputs


def get_device_node_mapping(gm: torch.fx.GraphModule):
    device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
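    # Return a reason string when cudagraphs must be skipped (multiple devices,
    # CPU nodes, or incompatible ops in the graph); None means it is safe to
    # apply cudagraphs.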
    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if has_incompatible_cudagraph_ops(aot_model):
        return "skipping cudagraphs due to incompatible op"

    return None


def get_device_index(gm) -> int:
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm) -> List[Optional[str]]:
    output = output_node(gm)
    assert len(output.args) == 1
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in output.args[0]
    ]


def cudagraphs(dynamo_model, dynamo_inputs):
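    # Backend entry point: compile forward/backward/inference graphs with
    # AOTAutograd and wrap them in cudagraph trees via cudagraphify_impl,
    # unless check_for_skip disables cudagraphs for this graph.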
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            BoxedBool.disable(do_cudagraphs)
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholders(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    def backward_cudagraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)

            # See [Backward Generation Handling]
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholders(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        return cudagraphs(model, inputs)


# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
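
# Example (illustrative, not part of the original file): the backend registered
# above can be selected by name, e.g.
#
#     compiled_model = torch.compile(model, backend="cudagraphs")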


def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
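    # Running the model once on a side stream (below) performs one-time lazy
    # initialization and allocations outside of graph capture.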
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
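

# Example (illustrative): capture once, then replay on new inputs.
#
#     compiled = cudagraphs_inner(model, example_inputs)
#     outputs = compiled(*new_inputs)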