This PR supports mutating inputs in cudagraph trees when those inputs are outputs from a previous cudagraph. Please check #121861 for more details.

# Note on Optimistic Mutation Check

To determine whether to apply cudagraph, we need to check input mutations, which fall into four categories: a) no mutation; b) mutation on parameters/buffers; c) mutation on cudagraph-recorded tensors; d) mutation on non-cudagraph-recorded tensors. We can apply cudagraph for types a, b, and c, but not for type d. The input mutation type depends on the function, the current node, and the inputs. Since `check_for_mutation` is slow, there is a trade-off between making type c or type d faster:

- To make type d faster, we would `check_for_mutation` and call the eager function early. However, this adds unnecessary overhead to types a, b, and c due to the extra check.
- To make type c faster, we would skip `check_for_mutation` at the beginning and only `check_for_mutation` before `record_function` for a new function. This removes the overhead of `check_for_mutation` for types a, b, and c, but adds extra overhead to type d due to `check_invariants` on all children nodes.

Instead, we design an optimistic mutation check. The assumption is that, given a function and a node, the input mutation type usually remains the same across inputs. So, once we have detected a function on a node as type d, we never treat it as type c again. The detailed design is as follows (a sketch follows at the end of this note):

- [Slow Path] On the first invocation of a function on a node, we run `check_for_mutation` once and cache the input mutation type as `non_cudagraph_managed_mutation[node_id][func_id]`.
- [Fast Path] On subsequent invocations of a function on a node, we skip `check_for_mutation`. If `non_cudagraph_managed_mutation[node_id][func_id]` is true, we directly call the eager function. Otherwise, we `check_invariants` and call the cudagraph function.
- [Slow Path] Before `record_function`, we run `check_for_mutation` again.

**Q1: Would there be overhead for types a, b, c, d?**
A: No. We only check input mutation types on the first invocation of a function on a node.

**Q2: If a function happens to be type c during the first invocation on a node, could we detect it as type d in the future?**
A: Yes. This is done by `check_invariants` and guarantees correctness.

**Q3: If a function happens to be type d during the first invocation on a node, could it still be recognized as type c in the future?**
A: No. But this should rarely happen, according to our assumption. In the rare case that it does, there is no correctness issue and the performance is the same as the eager (or inductor-optimized) function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123231
Approved by: https://github.com/eellison
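Below is a minimal sketch of the slow-path/fast-path dispatch described in this note. It is illustrative only: `check_for_mutation` and `check_invariants` stand in for the real cudagraph-trees helpers, and the eager/cudagraph callables elide the surrounding tree-manager machinery.

```python
from collections import defaultdict

# node_id -> func_id -> whether inputs mutate non-cudagraph-recorded tensors (type d)
non_cudagraph_managed_mutation = defaultdict(dict)

def dispatch(node_id, func_id, eager_fn, cudagraph_fn, inputs):
    if func_id not in non_cudagraph_managed_mutation[node_id]:
        # Slow path: first invocation of this function on this node.
        non_cudagraph_managed_mutation[node_id][func_id] = check_for_mutation(inputs)
    if non_cudagraph_managed_mutation[node_id][func_id]:
        return eager_fn(inputs)  # type d: run the eager (or inductor-optimized) function
    check_invariants(node_id, inputs)  # can still reclassify type c as type d later
    return cudagraph_fn(inputs)  # types a/b/c: safe to replay the cudagraph
```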
# mypy: ignore-errors

import functools
import operator
from collections import defaultdict
from typing import Dict, List, Optional

import torch
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    get_placeholders,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    has_incompatible_cudagraph_ops,
    num_fw_fixed_arguments,
    output_node,
)

from torch.multiprocessing.reductions import StorageWeakRef

from .registry import register_backend

perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")


def find_input_mutations(g):
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]

    # TODO: error on unrecognized nodes
    return mutated_inputs
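
# A minimal illustration of the alias_info check above (a sketch, assuming an
# eager PyTorch install): the first argument of an in-place op such as
# aten.add_ is marked as written in its schema, so find_input_mutations would
# flag any placeholder whose storage aliases that argument.
#
#   schema = torch.ops.aten.add_.Tensor._schema
#   assert schema.arguments[0].alias_info.is_write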


def get_device_node_mapping(gm: torch.fx.GraphModule):
    device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if has_incompatible_cudagraph_ops(aot_model):
        return "skipping cudagraphs due to incompatible op"

    return None


def get_device_index(gm) -> int:
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm) -> List[Optional[str]]:
    output = output_node(gm)
    assert len(output.args) == 1
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in output.args[0]
    ]


def cudagraphs(dynamo_model, dynamo_inputs):
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            BoxedBool.disable(do_cudagraphs)
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholders(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    def backward_cudagraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)

            # See [Backward Generation Handling]
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholders(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        return cudagraphs(model, inputs)


# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
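
# Example (a sketch; `model` and `inputs` below are placeholders): the
# registered backend is selected through torch.compile by name.
#
#   compiled = torch.compile(model, backend="cudagraphs")
#   compiled(*inputs)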


def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
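
# Example usage (a sketch; requires a CUDA device, and the module and shapes
# below are hypothetical):
#
#   model = torch.nn.Linear(8, 8).cuda()
#   run = cudagraphs_inner(model, [torch.randn(4, 8, device="cuda")])
#   (out,) = run(torch.randn(4, 8, device="cuda"))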